chore: merge upstream v0.1.124-125, keep Windsurf/Antigravity customizations

Upstream changes:
- feat: 邮箱 + GitHub + Google OAuth 快捷登录
- feat: Codex image bridge 开关
- feat: 内容审核 (content moderation) — 新增 contentModerationService/Handler
- feat: redeem code 返利、批量并发 API、markdown 页面渲染
- feat: 登录注册条款确认
- fix(security): pages API 加 JWT + 可见性校验
- fix: 修复 markdown 页面图片路径
- fix(gateway): 不再默认注入 redact thinking beta
- fix: 稳定 anthropic passthrough 超时错误
- chore: VERSION 升到 0.1.125 + golang:1.26.3-alpine

Conflict resolutions:
- Dockerfile/backend/Dockerfile: 取 upstream golang:1.26.3-alpine
- backend/go.mod: 取 upstream term v0.42.0,保留定制 protobuf v1.36.10
- frontend/src/api/admin/index.ts: 并集 (windsurf + riskControl)
- backend/cmd/server/wire_gen.go: 接 upstream contentModeration*,保留 windsurfHandler/windsurfGatewayService/billingCacheService/requestEventBus;并通过 wire 重生成
- frontend/src/views/admin/AccountsView.vue: 采用 upstream 双层布局 + OpenAI Meta,保留 is_enterprise prop 和 Windsurf tier badge

Note:
- WIP commit (de048fad) preserved Windsurf tier access service / NLU
  extractor / ops log stream / Google OAuth login modal et al before merge.
- 3 pre-existing go vet issues in test files (NewOpsHandler, RegisterGatewayRoutes,
  DefaultCLIProductVersion) are unrelated to this merge — leftover from local
  customization refactors; production code (go build ./...) passes.
This commit is contained in:
win 2026-05-09 01:42:39 +08:00
commit 7347dfffc1
136 changed files with 15502 additions and 425 deletions

View File

@ -20,7 +20,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
go version | grep -q 'go1.26.3'
- name: Unit tests
working-directory: backend
run: make test-unit
@ -60,7 +60,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
go version | grep -q 'go1.26.3'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:

View File

@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
go version | grep -q 'go1.26.3'
# Docker setup for GoReleaser
- name: Set up QEMU

View File

@ -23,7 +23,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
go version | grep -q 'go1.26.3'
- name: Run govulncheck
working-directory: backend
run: |

View File

@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26-alpine
ARG GOLANG_IMAGE=golang:1.26.3-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
ARG GOPROXY=https://goproxy.cn,direct

View File

@ -72,8 +72,8 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
</tr>
<tr>
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/register?aff=SUB2API">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
</tr>
<tr>

View File

@ -71,8 +71,8 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
</tr>
<tr>
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/register?aff=SUB2API">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
</tr>
<tr>

View File

@ -71,8 +71,8 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
</tr>
<tr>
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/register?aff=SUB2API">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
</tr>
<tr>

View File

@ -1,4 +1,4 @@
FROM golang:1.26-alpine
FROM golang:1.26.3-alpine
WORKDIR /app

View File

@ -1 +1 @@
0.1.123
0.1.125

View File

@ -8,11 +8,6 @@ package main
import (
"context"
"log"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
@ -23,9 +18,14 @@ import (
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"log"
"net/http"
"sync"
"time"
)
import (
_ "embed"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
)
@ -74,7 +74,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
@ -136,26 +136,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider)
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
windsurfLSService := service.ProvideWindsurfLSService(configConfig)
windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository)
windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache)
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
@ -237,19 +236,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
contentModerationRepository := repository.NewContentModerationRepository(db)
contentModerationHashCache := repository.NewContentModerationHashCache(redisClient)
contentModerationService := service.NewContentModerationService(settingRepository, contentModerationRepository, contentModerationHashCache, groupRepository, userRepository, apiKeyAuthCacheInvalidator, emailService)
contentModerationHandler := admin.NewContentModerationHandler(contentModerationService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService)
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository)
windsurfTierAccessService := service.ProvideWindsurfTierAccessService(configConfig, accountRepository)
windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService, windsurfTierAccessService)
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler, affiliateHandler)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, windsurfHandler, affiliateHandler)
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
@ -274,6 +277,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService)
application := &Application{
@ -485,6 +489,12 @@ func provideCleanup(
}
return nil
}},
{"WindsurfLSService", func() error {
if windsurfLS != nil {
windsurfLS.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@ -16,6 +16,8 @@ import (
var authProviderTypes = map[string]struct{}{
"email": {},
"github": {},
"google": {},
"linuxdo": {},
"oidc": {},
"wechat": {},

View File

@ -83,10 +83,10 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("github"))
require.Error(t, validator("unknown"))
}
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {

View File

@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
field.String("signup_source").
Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc":
case "email", "linuxdo", "wechat", "oidc", "github", "google":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
}
}).
Default("email"),

View File

@ -1,6 +1,6 @@
module github.com/Wei-Shaw/sub2api
go 1.26.2
go 1.26.3
require (
connectrpc.com/connect v1.19.2
@ -21,6 +21,7 @@ require (
github.com/google/wire v0.7.0
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/klauspost/compress v1.18.2
github.com/lib/pq v1.10.9
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/otp v1.5.0
@ -40,11 +41,11 @@ require (
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.49.0
golang.org/x/crypto v0.50.0
golang.org/x/image v0.39.0
golang.org/x/net v0.52.0
golang.org/x/net v0.53.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
golang.org/x/term v0.42.0
google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
@ -112,7 +113,6 @@ require (
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
@ -176,7 +176,7 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.43.0 // indirect
google.golang.org/grpc v1.75.1 // indirect

View File

@ -187,8 +187,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@ -224,8 +222,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -259,8 +255,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -290,8 +284,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@ -324,8 +316,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@ -417,16 +407,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -438,10 +428,10 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=

View File

@ -79,6 +79,8 @@ type Config struct {
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
@ -248,6 +250,19 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type EmailOAuthProviderConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
EmailsURL string `mapstructure:"emails_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
}
const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
@ -619,6 +634,9 @@ type GatewayConfig struct {
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// CodexImageGenerationBridgeEnabled: 是否为 Codex `/v1/responses` 自动注入 image_generation 工具和桥接指令。
// 默认关闭,避免纯文本 Codex 请求被意外改写;显式携带 image_generation 工具的请求仍按分组能力转发。
CodexImageGenerationBridgeEnabled bool `mapstructure:"codex_image_generation_bridge_enabled"`
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
// 模板渲染后会直接覆盖最终 instructions若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
@ -1773,6 +1791,7 @@ func setDefaults() {
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.codex_image_generation_bridge_enabled", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI Responses WebSocket默认开启可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,344 @@
package admin
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
)
func TestParseCodexSessionImportEntriesSupportsRawTokenJSONAndArray(t *testing.T) {
token1 := "raw-access-token-1"
token2 := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{
"email": "json@example.com",
})
token3 := "raw-access-token-3"
req := CodexSessionImportRequest{
Content: fmt.Sprintf("%s\n{\"accessToken\":%q}\n[%q]", token1, token2, token3),
}
entries, err := parseCodexSessionImportEntries(req)
if err != nil {
t.Fatalf("parseCodexSessionImportEntries error = %v", err)
}
if len(entries) != 3 {
t.Fatalf("len(entries) = %d, want 3", len(entries))
}
first, err := normalizeCodexImportEntry(entries[0])
if err != nil {
t.Fatalf("normalize raw token error = %v", err)
}
if first.Credentials["access_token"] != token1 {
t.Fatalf("raw token access_token = %v, want %s", first.Credentials["access_token"], token1)
}
second, err := normalizeCodexImportEntry(entries[1])
if err != nil {
t.Fatalf("normalize json token error = %v", err)
}
if second.Email != "json@example.com" {
t.Fatalf("email = %q, want json@example.com", second.Email)
}
third, err := normalizeCodexImportEntry(entries[2])
if err != nil {
t.Fatalf("normalize array token error = %v", err)
}
if third.Credentials["access_token"] != token3 {
t.Fatalf("array token access_token = %v, want %s", third.Credentials["access_token"], token3)
}
}
func TestParseCodexSessionImportEntriesFallsBackToLineModeForMixedJSONAndToken(t *testing.T) {
req := CodexSessionImportRequest{
Content: "{\"accessToken\":\"json-line-token\"}\nraw-line-token",
}
entries, err := parseCodexSessionImportEntries(req)
if err != nil {
t.Fatalf("parseCodexSessionImportEntries error = %v", err)
}
if len(entries) != 2 {
t.Fatalf("len(entries) = %d, want 2", len(entries))
}
first, err := normalizeCodexImportEntry(entries[0])
if err != nil {
t.Fatalf("normalize json line error = %v", err)
}
if first.Credentials["access_token"] != "json-line-token" {
t.Fatalf("json line access_token = %v, want json-line-token", first.Credentials["access_token"])
}
second, err := normalizeCodexImportEntry(entries[1])
if err != nil {
t.Fatalf("normalize raw line error = %v", err)
}
if second.Credentials["access_token"] != "raw-line-token" {
t.Fatalf("raw line access_token = %v, want raw-line-token", second.Credentials["access_token"])
}
}
func TestNormalizeCodexSessionJSONExtractsCredentialsAndIgnoresSessionToken(t *testing.T) {
accessToken := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{
"email": "claim@example.com",
"https://api.openai.com/auth": map[string]any{
"chatgpt_account_id": "acct-from-claim",
"chatgpt_user_id": "user-from-claim",
"chatgpt_plan_type": "plus",
"poid": "org-from-claim",
},
})
raw := map[string]any{
"user": map[string]any{
"id": "user-from-json",
"name": "Sup OO",
"email": "json@example.com",
"image": "https://example.com/avatar.png",
},
"account": map[string]any{
"id": "acct-from-json",
"planType": "free",
},
"accessToken": accessToken,
"sessionToken": "secret-session-token",
"expires": "2026-08-05T13:40:42.836Z",
}
item, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: raw})
if err != nil {
t.Fatalf("normalizeCodexImportEntry error = %v", err)
}
if item.Credentials["access_token"] != accessToken {
t.Fatalf("access_token not stored")
}
if item.Credentials["email"] != "json@example.com" {
t.Fatalf("email = %v, want json@example.com", item.Credentials["email"])
}
if item.Credentials["chatgpt_account_id"] != "acct-from-json" {
t.Fatalf("chatgpt_account_id = %v, want acct-from-json", item.Credentials["chatgpt_account_id"])
}
if item.Credentials["chatgpt_user_id"] != "user-from-json" {
t.Fatalf("chatgpt_user_id = %v, want user-from-json", item.Credentials["chatgpt_user_id"])
}
if item.Credentials["plan_type"] != "free" {
t.Fatalf("plan_type = %v, want free", item.Credentials["plan_type"])
}
if _, ok := item.Credentials["session_token"]; ok {
t.Fatalf("session_token should not be written to credentials")
}
if item.Extra["session_token_present"] != true {
t.Fatalf("session_token_present = %v, want true", item.Extra["session_token_present"])
}
if item.Extra["session_expires_at"] != "2026-08-05T13:40:42Z" {
t.Fatalf("session_expires_at = %v", item.Extra["session_expires_at"])
}
if item.TokenExpiresAt == nil {
t.Fatalf("TokenExpiresAt should be parsed from accessToken")
}
}
func TestMergeCodexImportCredentialsClearsStaleRefreshFieldsWhenIncomingHasNoRefreshToken(t *testing.T) {
existing := map[string]any{
"access_token": "old-access-token",
"refresh_token": "old-refresh-token",
"client_id": "old-client-id",
"id_token": "old-id-token",
"model_mapping": map[string]any{"from": "existing"},
"chatgpt_account_id": "acct-old",
"unrelated_existing": "keep",
}
incoming := map[string]any{
"access_token": "new-access-token",
"expires_at": "2026-08-05T13:40:42Z",
"chatgpt_account_id": "acct-new",
}
item := &codexImportAccount{
AccessToken: "new-access-token",
}
merged := mergeCodexImportCredentials(existing, incoming, item)
if merged["access_token"] != "new-access-token" {
t.Fatalf("access_token = %v, want new-access-token", merged["access_token"])
}
if merged["chatgpt_account_id"] != "acct-new" {
t.Fatalf("chatgpt_account_id = %v, want acct-new", merged["chatgpt_account_id"])
}
if _, ok := merged["refresh_token"]; ok {
t.Fatalf("refresh_token should be cleared")
}
if _, ok := merged["client_id"]; ok {
t.Fatalf("client_id should be cleared")
}
if _, ok := merged["id_token"]; ok {
t.Fatalf("id_token should be cleared")
}
if merged["unrelated_existing"] != "keep" {
t.Fatalf("unrelated_existing = %v, want keep", merged["unrelated_existing"])
}
if _, ok := merged["model_mapping"]; !ok {
t.Fatalf("model_mapping should be preserved")
}
}
func TestMergeCodexImportCredentialsKeepsRefreshFieldsWhenIncomingHasRefreshToken(t *testing.T) {
existing := map[string]any{
"refresh_token": "old-refresh-token",
"client_id": "old-client-id",
"id_token": "old-id-token",
}
incoming := map[string]any{
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"client_id": "new-client-id",
"id_token": "new-id-token",
}
item := &codexImportAccount{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
IDToken: "new-id-token",
}
merged := mergeCodexImportCredentials(existing, incoming, item)
if merged["refresh_token"] != "new-refresh-token" {
t.Fatalf("refresh_token = %v, want new-refresh-token", merged["refresh_token"])
}
if merged["client_id"] != "new-client-id" {
t.Fatalf("client_id = %v, want new-client-id", merged["client_id"])
}
if merged["id_token"] != "new-id-token" {
t.Fatalf("id_token = %v, want new-id-token", merged["id_token"])
}
}
func TestNormalizeCodexImportRejectsExpiredAccessToken(t *testing.T) {
expiredToken := buildCodexImportTestJWT(t, time.Now().Add(-time.Hour), map[string]any{})
_, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: expiredToken})
if err == nil {
t.Fatal("normalizeCodexImportEntry error = nil, want expired token error")
}
if !strings.Contains(err.Error(), "已过期") {
t.Fatalf("error = %v, want expired token message", err)
}
}
func TestResolveCodexImportExpiryForNoRefreshTokenUsesTokenExpiry(t *testing.T) {
tokenExpiresAt := time.Now().Add(time.Hour).UTC()
item := &codexImportAccount{
AccessToken: "access-token",
Credentials: map[string]any{"access_token": "access-token"},
TokenExpiresAt: &tokenExpiresAt,
WarningTexts: []string{},
}
disabled := false
req := CodexSessionImportRequest{AutoPauseOnExpired: &disabled}
accountExpiresAt, credentialExpiresAt, autoPause, warnings, err := resolveCodexImportExpiry(req, item)
if err != nil {
t.Fatalf("resolveCodexImportExpiry error = %v", err)
}
if accountExpiresAt == nil || *accountExpiresAt != tokenExpiresAt.Unix() {
t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, tokenExpiresAt.Unix())
}
if credentialExpiresAt == nil || credentialExpiresAt.Unix() != tokenExpiresAt.Unix() {
t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, tokenExpiresAt)
}
if autoPause == nil || !*autoPause {
t.Fatalf("autoPause = %v, want true", autoPause)
}
if len(warnings) == 0 {
t.Fatalf("warnings should not be empty")
}
}
func TestResolveCodexImportExpiryForNoRefreshTokenRequiresExpiry(t *testing.T) {
item := &codexImportAccount{
AccessToken: "opaque-access-token",
Credentials: map[string]any{"access_token": "opaque-access-token"},
WarningTexts: []string{},
}
_, _, _, _, err := resolveCodexImportExpiry(CodexSessionImportRequest{}, item)
if err == nil {
t.Fatal("resolveCodexImportExpiry error = nil, want missing expiry error")
}
if !strings.Contains(err.Error(), "无法解析 accessToken 过期时间") {
t.Fatalf("error = %v, want missing expiry message", err)
}
}
func TestResolveCodexImportExpiryForNoRefreshTokenUsesEarlierRequestExpiry(t *testing.T) {
tokenExpiresAt := time.Now().Add(2 * time.Hour).UTC()
requestExpiresAt := time.Now().Add(time.Hour).UTC()
item := &codexImportAccount{
AccessToken: "access-token",
Credentials: map[string]any{"access_token": "access-token"},
TokenExpiresAt: &tokenExpiresAt,
WarningTexts: []string{},
}
reqUnix := requestExpiresAt.Unix()
req := CodexSessionImportRequest{ExpiresAt: &reqUnix}
accountExpiresAt, credentialExpiresAt, _, _, err := resolveCodexImportExpiry(req, item)
if err != nil {
t.Fatalf("resolveCodexImportExpiry error = %v", err)
}
if accountExpiresAt == nil || *accountExpiresAt != requestExpiresAt.Unix() {
t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, requestExpiresAt.Unix())
}
if credentialExpiresAt == nil || credentialExpiresAt.Unix() != requestExpiresAt.Unix() {
t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, requestExpiresAt)
}
}
func TestCodexIdentityKeysPreferStrongIdentifiers(t *testing.T) {
keys := buildCodexIdentityKeys("acct-1", "user-1", "same@example.com", "token")
for _, key := range keys {
if strings.HasPrefix(key, "email:") {
t.Fatalf("strong identity should not include email fallback: %v", keys)
}
}
keys = buildCodexIdentityKeys("", "", "same@example.com", "token")
hasEmail := false
for _, key := range keys {
if key == "email:same@example.com" {
hasEmail = true
}
}
if !hasEmail {
t.Fatalf("weak identity should include email fallback: %v", keys)
}
}
func buildCodexImportTestJWT(t *testing.T, exp time.Time, extraClaims map[string]any) string {
t.Helper()
header := map[string]any{
"alg": "none",
"typ": "JWT",
}
claims := map[string]any{
"sub": "user-from-sub",
"exp": exp.Unix(),
"iat": time.Now().Unix(),
}
for k, v := range extraClaims {
claims[k] = v
}
headerBytes, err := json.Marshal(header)
if err != nil {
t.Fatalf("marshal header: %v", err)
}
claimBytes, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
return base64.RawURLEncoding.EncodeToString(headerBytes) + "." + base64.RawURLEncoding.EncodeToString(claimBytes) + "."
}

View File

@ -175,6 +175,10 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
return &user, nil
}
func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
return len(userIDs), nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}

View File

@ -0,0 +1,238 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type ContentModerationHandler struct {
service *service.ContentModerationService
}
func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler {
return &ContentModerationHandler{service: svc}
}
type contentModerationConfigRequest struct {
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
APIKey *string `json:"api_key"`
APIKeys *[]string `json:"api_keys"`
APIKeysMode string `json:"api_keys_mode"`
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
}
type contentModerationAPIKeyTestRequest struct {
APIKeys []string `json:"api_keys"`
BaseURL string `json:"base_url"`
Model string `json:"model"`
TimeoutMS int `json:"timeout_ms"`
Prompt string `json:"prompt"`
Images []string `json:"images"`
}
type contentModerationHashRequest struct {
InputHash string `json:"input_hash"`
}
func (h *ContentModerationHandler) GetConfig(c *gin.Context) {
cfg, err := h.service.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
var req contentModerationConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{
Enabled: req.Enabled,
Mode: req.Mode,
BaseURL: req.BaseURL,
Model: req.Model,
APIKey: req.APIKey,
APIKeys: req.APIKeys,
APIKeysMode: req.APIKeysMode,
DeleteAPIKeyHashes: req.DeleteAPIKeyHashes,
ClearAPIKey: req.ClearAPIKey,
TimeoutMS: req.TimeoutMS,
SampleRate: req.SampleRate,
AllGroups: req.AllGroups,
GroupIDs: req.GroupIDs,
RecordNonHits: req.RecordNonHits,
WorkerCount: req.WorkerCount,
QueueSize: req.QueueSize,
BlockStatus: req.BlockStatus,
BlockMessage: req.BlockMessage,
EmailOnHit: req.EmailOnHit,
AutoBanEnabled: req.AutoBanEnabled,
BanThreshold: req.BanThreshold,
ViolationWindowHours: req.ViolationWindowHours,
RetryCount: req.RetryCount,
HitRetentionDays: req.HitRetentionDays,
NonHitRetentionDays: req.NonHitRetentionDays,
PreHashCheckEnabled: req.PreHashCheckEnabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) {
var req contentModerationAPIKeyTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{
APIKeys: req.APIKeys,
BaseURL: req.BaseURL,
Model: req.Model,
TimeoutMS: req.TimeoutMS,
Prompt: req.Prompt,
Images: req.Images,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *ContentModerationHandler) GetStatus(c *gin.Context) {
status, err := h.service.GetStatus(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, status)
}
func (h *ContentModerationHandler) ListLogs(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := service.ContentModerationLogFilter{
Pagination: pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortOrder: pagination.SortOrderDesc,
},
Result: c.Query("result"),
Endpoint: c.Query("endpoint"),
Search: c.Query("search"),
}
if raw := strings.TrimSpace(c.Query("group_id")); raw != "" {
groupID, err := strconv.ParseInt(raw, 10, 64)
if err != nil || groupID <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &groupID
}
if raw := strings.TrimSpace(c.Query("from")); raw != "" {
t, _, err := parseContentModerationDate(raw)
if err != nil {
response.BadRequest(c, "Invalid from")
return
}
filter.From = &t
}
if raw := strings.TrimSpace(c.Query("to")); raw != "" {
t, dateOnly, err := parseContentModerationDate(raw)
if err != nil {
response.BadRequest(c, "Invalid to")
return
}
if dateOnly {
t = t.Add(24*time.Hour - time.Nanosecond)
}
filter.To = &t
}
items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize)
}
func (h *ContentModerationHandler) UnbanUser(c *gin.Context) {
userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64)
if err != nil || userID <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
result, err := h.service.UnbanUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) {
var req contentModerationHashRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) {
result, err := h.service.ClearFlaggedInputHashes(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func parseContentModerationDate(raw string) (time.Time, bool, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return time.Time{}, false, nil
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return t, false, nil
}
t, err := time.Parse("2006-01-02", raw)
return t, err == nil, err
}

View File

@ -117,6 +117,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
LoginAgreementEnabled: settings.LoginAgreementEnabled,
LoginAgreementMode: settings.LoginAgreementMode,
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocumentsToDTO(settings.LoginAgreementDocuments),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@ -169,6 +173,16 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
GitHubOAuthClientID: settings.GitHubOAuthClientID,
GitHubOAuthClientSecretConfigured: settings.GitHubOAuthClientSecretConfigured,
GitHubOAuthRedirectURL: settings.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: settings.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
GoogleOAuthClientID: settings.GoogleOAuthClientID,
GoogleOAuthClientSecretConfigured: settings.GoogleOAuthClientSecretConfigured,
GoogleOAuthRedirectURL: settings.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: settings.GoogleOAuthFrontendRedirectURL,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
@ -185,6 +199,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
RiskControlEnabled: settings.RiskControlEnabled,
AffiliateRebateRate: settings.AffiliateRebateRate,
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
@ -294,17 +309,50 @@ func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.O
return &service.OpenAIFastPolicySettings{Rules: rules}
}
func loginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
result := make([]dto.LoginAgreementDocument, 0, len(items))
for _, item := range items {
result = append(result, dto.LoginAgreementDocument{
ID: item.ID,
Title: item.Title,
ContentMD: item.ContentMD,
})
}
return result
}
func loginAgreementDocumentsToService(items []dto.LoginAgreementDocument) []service.LoginAgreementDocument {
result := make([]service.LoginAgreementDocument, 0, len(items))
for _, item := range items {
title := strings.TrimSpace(item.Title)
content := strings.TrimSpace(item.ContentMD)
if title == "" && content == "" {
continue
}
result = append(result, service.LoginAgreementDocument{
ID: strings.TrimSpace(item.ID),
Title: title,
ContentMD: content,
})
}
return result
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
LoginAgreementMode string `json:"login_agreement_mode"`
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
LoginAgreementDocuments []dto.LoginAgreementDocument `json:"login_agreement_documents"`
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@ -368,6 +416,17 @@ type UpdateSettingsRequest struct {
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GitHubOAuthClientID string `json:"github_oauth_client_id"`
GitHubOAuthClientSecret string `json:"github_oauth_client_secret"`
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
GoogleOAuthClientID string `json:"google_oauth_client_id"`
GoogleOAuthClientSecret string `json:"google_oauth_client_secret"`
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
@ -413,6 +472,16 @@ type UpdateSettingsRequest struct {
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
@ -497,6 +566,9 @@ type UpdateSettingsRequest struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
// 风控中心功能开关
RiskControlEnabled *bool `json:"risk_control_enabled"`
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
@ -633,6 +705,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
}
loginAgreementMode := strings.ToLower(strings.TrimSpace(req.LoginAgreementMode))
if loginAgreementMode == "" {
loginAgreementMode = strings.ToLower(strings.TrimSpace(previousSettings.LoginAgreementMode))
}
switch loginAgreementMode {
case "", "modal":
loginAgreementMode = "modal"
case "checkbox":
default:
response.BadRequest(c, "Login agreement mode must be modal or checkbox")
return
}
loginAgreementUpdatedAt := strings.TrimSpace(req.LoginAgreementUpdatedAt)
if loginAgreementUpdatedAt == "" {
loginAgreementUpdatedAt = strings.TrimSpace(previousSettings.LoginAgreementUpdatedAt)
}
loginAgreementDocuments := loginAgreementDocumentsToService(req.LoginAgreementDocuments)
if len(loginAgreementDocuments) == 0 {
loginAgreementDocuments = previousSettings.LoginAgreementDocuments
}
for _, doc := range loginAgreementDocuments {
if strings.TrimSpace(doc.Title) == "" {
response.BadRequest(c, "Login agreement document title is required")
return
}
if len(doc.Title) > 80 {
response.BadRequest(c, "Login agreement document title is too long (max 80 characters)")
return
}
if len(doc.ContentMD) > 200*1024 {
response.BadRequest(c, "Login agreement document content is too large (max 200KB)")
return
}
}
if req.LoginAgreementEnabled && len(loginAgreementDocuments) == 0 {
response.BadRequest(c, "Login agreement documents are required when enabled")
return
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
@ -994,17 +1104,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
return
}
if strings.TrimSpace(item.URL) == "" {
response.BadRequest(c, "Custom menu item URL is required")
return
}
if len(item.URL) > maxMenuItemURLLen {
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
return
urlTrimmed := strings.TrimSpace(item.URL)
if strings.HasPrefix(urlTrimmed, "md:") {
// Markdown page mode: URL = "md:<slug>"
slug := strings.TrimPrefix(urlTrimmed, "md:")
if slug == "" {
response.BadRequest(c, "Custom menu item markdown slug cannot be empty (use md:slug format)")
return
}
} else {
if urlTrimmed == "" {
response.BadRequest(c, "Custom menu item URL is required (use md:slug for markdown pages)")
return
}
if len(item.URL) > maxMenuItemURLLen {
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(urlTrimmed); err != nil {
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL or md:<slug>")
return
}
}
if item.Visibility != "user" && item.Visibility != "admin" {
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
@ -1148,6 +1268,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
@ -1200,6 +1324,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
GitHubOAuthClientID: req.GitHubOAuthClientID,
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
GoogleOAuthClientID: req.GoogleOAuthClientID,
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
@ -1365,6 +1499,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AffiliateEnabled
}(),
RiskControlEnabled: func() bool {
if req.RiskControlEnabled != nil {
return *req.RiskControlEnabled
}
return previousSettings.RiskControlEnabled
}(),
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
@ -1396,6 +1536,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
},
GitHub: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultGitHubConcurrency, previousAuthSourceDefaults.GitHub.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
},
Google: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultGoogleConcurrency, previousAuthSourceDefaults.Google.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
@ -1486,6 +1640,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
LoginAgreementEnabled: updatedSettings.LoginAgreementEnabled,
LoginAgreementMode: updatedSettings.LoginAgreementMode,
LoginAgreementUpdatedAt: updatedSettings.LoginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocumentsToDTO(updatedSettings.LoginAgreementDocuments),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@ -1538,6 +1696,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: updatedSettings.GitHubOAuthEnabled,
GitHubOAuthClientID: updatedSettings.GitHubOAuthClientID,
GitHubOAuthClientSecretConfigured: updatedSettings.GitHubOAuthClientSecretConfigured,
GitHubOAuthRedirectURL: updatedSettings.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: updatedSettings.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: updatedSettings.GoogleOAuthEnabled,
GoogleOAuthClientID: updatedSettings.GoogleOAuthClientID,
GoogleOAuthClientSecretConfigured: updatedSettings.GoogleOAuthClientSecretConfigured,
GoogleOAuthRedirectURL: updatedSettings.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: updatedSettings.GoogleOAuthFrontendRedirectURL,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
@ -1616,6 +1784,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
AffiliateEnabled: updatedSettings.AffiliateEnabled,
RiskControlEnabled: updatedSettings.RiskControlEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
@ -1685,6 +1855,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
if before.LoginAgreementEnabled != after.LoginAgreementEnabled {
changed = append(changed, "login_agreement_enabled")
}
if before.LoginAgreementMode != after.LoginAgreementMode {
changed = append(changed, "login_agreement_mode")
}
if before.LoginAgreementUpdatedAt != after.LoginAgreementUpdatedAt {
changed = append(changed, "login_agreement_updated_at")
}
if !equalLoginAgreementDocuments(before.LoginAgreementDocuments, after.LoginAgreementDocuments) {
changed = append(changed, "login_agreement_documents")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
@ -2004,6 +2186,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AffiliateEnabled != after.AffiliateEnabled {
changed = append(changed, "affiliate_enabled")
}
if before.RiskControlEnabled != after.RiskControlEnabled {
changed = append(changed, "risk_control_enabled")
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
@ -2027,6 +2212,8 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
{name: "oidc", before: before.OIDC, after: after.OIDC},
{name: "wechat", before: before.WeChat, after: after.WeChat},
{name: "github", before: before.GitHub, after: after.GitHub},
{name: "google", before: before.Google, after: after.Google},
}
for _, field := range fields {
if field.before.Balance != field.after.Balance {
@ -2141,6 +2328,16 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
data["auth_source_default_github_balance"] = authSourceDefaults.GitHub.Balance
data["auth_source_default_github_concurrency"] = authSourceDefaults.GitHub.Concurrency
data["auth_source_default_github_subscriptions"] = authSourceDefaults.GitHub.Subscriptions
data["auth_source_default_github_grant_on_signup"] = authSourceDefaults.GitHub.GrantOnSignup
data["auth_source_default_github_grant_on_first_bind"] = authSourceDefaults.GitHub.GrantOnFirstBind
data["auth_source_default_google_balance"] = authSourceDefaults.Google.Balance
data["auth_source_default_google_concurrency"] = authSourceDefaults.Google.Concurrency
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
@ -2170,6 +2367,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
return true
}
func equalLoginAgreementDocuments(a, b []service.LoginAgreementDocument) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].ID != b[i].ID || a[i].Title != b[i].Title || a[i].ContentMD != b[i].ContentMD {
return false
}
}
return true
}
func equalIntSlice(a, b []int) bool {
if len(a) != len(b) {
return false

View File

@ -477,3 +477,63 @@ func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
response.Success(c, status)
}
// BatchUpdateConcurrency 批量修改用户并发数
// POST /api/v1/admin/users/batch-concurrency
type BatchUpdateConcurrencyRequest struct {
UserIDs []int64 `json:"user_ids"`
All bool `json:"all"`
Concurrency int `json:"concurrency"`
Mode string `json:"mode" binding:"required,oneof=set add"`
}
func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
var req BatchUpdateConcurrencyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.All && len(req.UserIDs) == 0 {
response.BadRequest(c, "user_ids is required unless all=true")
return
}
if len(req.UserIDs) > 500 {
response.BadRequest(c, "user_ids cannot exceed 500")
return
}
var userIDs []int64
if req.All {
// Fetch all user IDs via pagination
page := 1
const pageSize = 500
for {
users, _, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, service.UserListFilters{}, "id", "asc")
if err != nil {
response.ErrorFrom(c, err)
return
}
for _, u := range users {
userIDs = append(userIDs, u.ID)
}
if len(users) < pageSize {
break
}
page++
}
} else {
userIDs = req.UserIDs
}
if len(userIDs) == 0 {
response.Success(c, gin.H{"affected": 0})
return
}
affected, err := h.adminService.BatchUpdateConcurrency(c.Request.Context(), userIDs, req.Concurrency, req.Mode)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"affected": affected})
}

View File

@ -0,0 +1,621 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
emailOAuthCookiePath = "/api/v1/auth/oauth"
emailOAuthStateCookieName = "email_oauth_state"
emailOAuthRedirectCookie = "email_oauth_redirect"
emailOAuthProviderCookie = "email_oauth_provider"
emailOAuthAffiliateCookie = "email_oauth_affiliate"
emailOAuthCookieMaxAgeSec = 10 * 60
emailOAuthDefaultRedirect = "/dashboard"
)
type emailOAuthTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
type emailOAuthProfile struct {
Subject string
Email string
EmailVerified bool
Username string
DisplayName string
AvatarURL string
Metadata map[string]any
}
func (h *AuthHandler) GitHubOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "github") }
func (h *AuthHandler) GoogleOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "google") }
func (h *AuthHandler) GitHubOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "github") }
func (h *AuthHandler) GoogleOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "google") }
func (h *AuthHandler) CompleteGitHubOAuthRegistration(c *gin.Context) {
h.completeEmailOAuthRegistration(c, "github")
}
func (h *AuthHandler) CompleteGoogleOAuthRegistration(c *gin.Context) {
h.completeEmailOAuthRegistration(c, "google")
}
func (h *AuthHandler) emailOAuthStart(c *gin.Context, provider string) {
cfg, err := h.getEmailOAuthConfig(c.Request.Context(), provider)
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = emailOAuthDefaultRedirect
}
secureCookie := isRequestHTTPS(c)
emailOAuthSetCookie(c, emailOAuthStateCookieName, encodeCookieValue(state), secureCookie)
emailOAuthSetCookie(c, emailOAuthRedirectCookie, encodeCookieValue(redirectTo), secureCookie)
emailOAuthSetCookie(c, emailOAuthProviderCookie, encodeCookieValue(provider), secureCookie)
if affCode := strings.TrimSpace(firstNonEmpty(c.Query("aff_code"), c.Query("aff"))); affCode != "" {
emailOAuthSetCookie(c, emailOAuthAffiliateCookie, encodeCookieValue(affCode), secureCookie)
} else {
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
}
authURL, err := buildEmailOAuthAuthorizeURL(cfg, state)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
func (h *AuthHandler) emailOAuthCallback(c *gin.Context, provider string) {
cfg, cfgErr := h.getEmailOAuthConfig(c.Request.Context(), provider)
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = "/auth/oauth/callback"
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
emailOAuthClearCookie(c, emailOAuthStateCookieName, secureCookie)
emailOAuthClearCookie(c, emailOAuthRedirectCookie, secureCookie)
emailOAuthClearCookie(c, emailOAuthProviderCookie, secureCookie)
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, emailOAuthStateCookieName)
if err != nil || expectedState == "" || expectedState != state {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
expectedProvider, _ := readCookieDecoded(c, emailOAuthProviderCookie)
if !strings.EqualFold(strings.TrimSpace(expectedProvider), provider) {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth provider", "")
return
}
redirectTo, _ := readCookieDecoded(c, emailOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = emailOAuthDefaultRedirect
}
tokenResp, err := exchangeEmailOAuthCode(c.Request.Context(), cfg, code)
if err != nil {
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(err.Error()))
return
}
profile, err := fetchEmailOAuthProfile(c.Request.Context(), provider, cfg, tokenResp)
if err != nil {
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch verified email", singleLine(err.Error()))
return
}
h.emailOAuthCallbackWithProfile(c, provider, cfg, frontendCallback, redirectTo, profile)
}
func (h *AuthHandler) emailOAuthCallbackWithProfile(
c *gin.Context,
provider string,
cfg config.EmailOAuthProviderConfig,
frontendCallback string,
redirectTo string,
profile *emailOAuthProfile,
) {
input := service.EmailOAuthIdentityInput{
ProviderType: provider,
ProviderKey: provider,
ProviderSubject: profile.Subject,
Email: profile.Email,
EmailVerified: profile.EmailVerified,
Username: profile.Username,
DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL,
UpstreamMetadata: profile.Metadata,
}
affiliateCode := h.emailOAuthAffiliateCode(c)
if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
return
} else if shouldCreate {
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) {
client := h.entClient()
if client == nil {
return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{
ProviderType: strings.TrimSpace(input.ProviderType),
ProviderKey: strings.TrimSpace(input.ProviderKey),
ProviderSubject: strings.TrimSpace(input.ProviderSubject),
})
if err != nil {
return false, err
}
email := strings.TrimSpace(strings.ToLower(input.Email))
if identityUser != nil {
if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
}
return false, nil
}
if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil {
if errors.Is(err, service.ErrUserNotFound) {
return true, nil
}
return false, err
}
return false, nil
}
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
if c == nil {
return ""
}
if code, err := readCookieDecoded(c, emailOAuthAffiliateCookie); err == nil {
return strings.TrimSpace(code)
}
return ""
}
func (h *AuthHandler) createEmailOAuthRegistrationPendingSession(
c *gin.Context,
provider string,
frontendCallback string,
redirectTo string,
profile *emailOAuthProfile,
) error {
if h == nil || profile == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
browserSessionKey, err := generateOAuthPendingBrowserSession()
if err != nil {
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
}
setOAuthPendingBrowserCookie(c, browserSessionKey, isRequestHTTPS(c))
email := strings.TrimSpace(strings.ToLower(profile.Email))
username := strings.TrimSpace(profile.Username)
affiliateCode := h.emailOAuthAffiliateCode(c)
upstreamClaims := map[string]any{
"email": email,
"email_verified": profile.EmailVerified,
"username": username,
"provider": provider,
"provider_key": provider,
"provider_subject": strings.TrimSpace(profile.Subject),
}
if strings.TrimSpace(profile.DisplayName) != "" {
upstreamClaims["suggested_display_name"] = strings.TrimSpace(profile.DisplayName)
}
if strings.TrimSpace(profile.AvatarURL) != "" {
upstreamClaims["suggested_avatar_url"] = strings.TrimSpace(profile.AvatarURL)
}
if affiliateCode != "" {
upstreamClaims["aff_code"] = affiliateCode
}
for key, value := range profile.Metadata {
if _, exists := upstreamClaims[key]; !exists {
upstreamClaims[key] = value
}
}
invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context())
pendingError := "registration_completion_required"
choiceReason := "registration_completion_required"
if invitationRequired {
pendingError = "invitation_required"
choiceReason = "invitation_required"
}
completionResponse := map[string]any{
"step": oauthPendingChoiceStep,
"error": pendingError,
"choice_reason": choiceReason,
"adoption_required": false,
"create_account_allowed": true,
"existing_account_bindable": false,
"force_email_on_signup": true,
"invitation_required": invitationRequired,
"email": email,
"resolved_email": email,
"provider": provider,
"redirect": redirectTo,
}
if strings.TrimSpace(frontendCallback) != "" {
completionResponse["frontend_callback"] = strings.TrimSpace(frontendCallback)
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: service.PendingAuthIdentityKey{ProviderType: provider, ProviderKey: provider, ProviderSubject: strings.TrimSpace(profile.Subject)},
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: completionResponse,
})
}
type completeEmailOAuthRequest struct {
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AffCode string `json:"aff_code,omitempty"`
}
func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider string) {
var req completeEmailOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
affiliateCode := strings.TrimSpace(req.AffCode)
if affiliateCode == "" {
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
}
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
c.Request.Context(),
strings.TrimSpace(session.ResolvedEmail),
req.Password,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
tx, err := client.Tx(c.Request.Context())
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
return
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
sessionForBinding := *session
sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims)
if strings.TrimSpace(req.InvitationCode) != "" {
sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode)
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{})
if err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
respondPendingOAuthBindingApplyError(c, err)
return
}
if err := h.authService.FinalizeOAuthEmailAccount(
txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
affiliateCode,
); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, err)
return
}
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := tx.Commit(); err != nil {
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
func (h *AuthHandler) getEmailOAuthConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetEmailOAuthProviderConfig(ctx, provider)
}
return config.EmailOAuthProviderConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
func buildEmailOAuthAuthorizeURL(cfg config.EmailOAuthProviderConfig, state string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", cfg.RedirectURL)
q.Set("state", state)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func exchangeEmailOAuthCode(ctx context.Context, cfg config.EmailOAuthProviderConfig, code string) (*emailOAuthTokenResponse, error) {
resp, err := req.C().
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetFormData(map[string]string{
"grant_type": "authorization_code",
"client_id": cfg.ClientID,
"client_secret": cfg.ClientSecret,
"code": code,
"redirect_uri": cfg.RedirectURL,
}).
Post(cfg.TokenURL)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("token endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
}
var tokenResp emailOAuthTokenResponse
if err := json.Unmarshal(resp.Bytes(), &tokenResp); err != nil {
return nil, err
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, errors.New("missing access_token")
}
return &tokenResp, nil
}
func fetchEmailOAuthProfile(ctx context.Context, provider string, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse) (*emailOAuthProfile, error) {
resp, err := req.C().
R().
SetContext(ctx).
SetBearerAuthToken(token.AccessToken).
SetHeader("Accept", "application/json").
Get(cfg.UserInfoURL)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("userinfo endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
}
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github":
return parseGitHubOAuthProfile(ctx, cfg, token, resp.String())
case "google":
return parseGoogleOAuthProfile(resp.String())
default:
return nil, errors.New("unsupported oauth provider")
}
}
func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse, body string) (*emailOAuthProfile, error) {
subject := strings.TrimSpace(gjson.Get(body, "id").String())
if subject == "" {
return nil, errors.New("github user id is missing")
}
email := ""
emailsURL := strings.TrimSpace(cfg.EmailsURL)
if emailsURL == "" {
return nil, errors.New("github verified email is missing")
}
verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken)
if err != nil {
return nil, err
}
email = verifiedEmail
if email == "" {
return nil, errors.New("github verified email is missing")
}
login := strings.TrimSpace(gjson.Get(body, "login").String())
name := strings.TrimSpace(gjson.Get(body, "name").String())
return &emailOAuthProfile{
Subject: subject,
Email: email,
EmailVerified: true,
Username: firstNonEmpty(login, name, "github_"+subject),
DisplayName: firstNonEmpty(name, login),
AvatarURL: strings.TrimSpace(gjson.Get(body, "avatar_url").String()),
Metadata: map[string]any{
"login": login,
},
}, nil
}
func fetchGitHubPrimaryVerifiedEmail(ctx context.Context, emailsURL string, accessToken string) (string, error) {
resp, err := req.C().
R().
SetContext(ctx).
SetBearerAuthToken(accessToken).
SetHeader("Accept", "application/json").
Get(emailsURL)
if err != nil {
return "", err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("github emails endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
}
items := gjson.Parse(resp.String()).Array()
for _, item := range items {
if item.Get("primary").Bool() && item.Get("verified").Bool() {
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
return email, nil
}
}
}
for _, item := range items {
if item.Get("verified").Bool() {
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
return email, nil
}
}
}
return "", errors.New("github verified email is missing")
}
func parseGoogleOAuthProfile(body string) (*emailOAuthProfile, error) {
subject := strings.TrimSpace(gjson.Get(body, "sub").String())
email := strings.TrimSpace(gjson.Get(body, "email").String())
verified := gjson.Get(body, "email_verified").Bool()
if subject == "" {
return nil, errors.New("google subject is missing")
}
if email == "" || !verified {
return nil, errors.New("google verified email is missing")
}
name := strings.TrimSpace(gjson.Get(body, "name").String())
return &emailOAuthProfile{
Subject: subject,
Email: email,
EmailVerified: true,
Username: firstNonEmpty(strings.TrimSpace(gjson.Get(body, "given_name").String()), name, email),
DisplayName: name,
AvatarURL: strings.TrimSpace(gjson.Get(body, "picture").String()),
Metadata: map[string]any{
"email_verified": true,
},
}, nil
}
func emailOAuthSetCookie(c *gin.Context, name, value string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: emailOAuthCookiePath,
MaxAge: emailOAuthCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func emailOAuthClearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: emailOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}

View File

@ -0,0 +1,414 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
state := "github-oauth-state"
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil)
req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)})
req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")})
req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")})
c.Request = req
profile := &emailOAuthProfile{
Subject: "github-123",
Email: "fresh@example.com",
EmailVerified: true,
Username: "fresh",
DisplayName: "Fresh User",
AvatarURL: "https://cdn.example/fresh.png",
Metadata: map[string]any{
"login": "fresh",
},
}
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
Enabled: true,
ClientID: "github-client",
ClientSecret: "github-secret",
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
FrontendRedirectURL: "/auth/oauth/callback",
}, "/auth/oauth/callback", "/dashboard", profile)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "/auth/oauth/callback")
require.NotContains(t, location, "access_token=")
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
session, err := client.PendingAuthSession.Query().Only(ctx)
require.NoError(t, err)
require.Equal(t, "github", session.ProviderType)
require.Equal(t, "github", session.ProviderKey)
require.Equal(t, "github-123", session.ProviderSubject)
require.Equal(t, "fresh@example.com", session.ResolvedEmail)
require.Equal(t, "/dashboard", session.RedirectTo)
require.Nil(t, session.TargetUserID)
completion, ok := readCompletionResponse(session.LocalFlowState)
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "invitation_required", completion["error"])
require.Equal(t, true, completion["invitation_required"])
require.Equal(t, "fresh@example.com", completion["email"])
require.Equal(t, "fresh@example.com", completion["resolved_email"])
require.Equal(t, true, completion["create_account_allowed"])
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName))
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName))
}
func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("existing@example.com").
SetUsername("existing").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil)
handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{
Enabled: true,
ClientID: "google-client",
ClientSecret: "google-secret",
RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback",
FrontendRedirectURL: "/auth/oauth/callback",
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
Subject: "google-123",
Email: "existing@example.com",
EmailVerified: true,
Username: "existing",
})
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "access_token=")
require.Contains(t, location, "redirect=%252Fdashboard")
sessionCount, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, sessionCount)
identityCount, err := client.AuthIdentity.Query().Where(
authidentity.ProviderTypeEQ("google"),
authidentity.ProviderSubjectEQ("google-123"),
).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
_ = user
}
func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) {
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyAffiliateEnabled: "true",
},
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
},
})
ctx := context.Background()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil)
req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")})
c.Request = req
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
Enabled: true,
ClientID: "github-client",
ClientSecret: "github-secret",
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
FrontendRedirectURL: "/auth/oauth/callback",
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
Subject: "github-aff-user",
Email: "aff-user@example.com",
EmailVerified: true,
Username: "aff-user",
})
require.Equal(t, http.StatusFound, recorder.Code)
require.NotContains(t, recorder.Header().Get("Location"), "access_token=")
userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
require.Empty(t, affiliateRepo.ensureUserIDs)
require.Empty(t, affiliateRepo.bindCalls)
session, err := client.PendingAuthSession.Query().Only(ctx)
require.NoError(t, err)
require.Equal(t, "aff-user@example.com", session.ResolvedEmail)
require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code"))
completion, ok := readCompletionResponse(session.LocalFlowState)
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "registration_completion_required", completion["error"])
require.Equal(t, false, completion["invitation_required"])
require.Equal(t, true, completion["create_account_allowed"])
require.Equal(t, true, completion["force_email_on_signup"])
require.Equal(t, "aff-user@example.com", completion["resolved_email"])
}
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002})
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
invitationEnabled: true,
settingValues: map[string]string{
service.SettingKeyAffiliateEnabled: "true",
},
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
},
})
ctx := context.Background()
invitation, err := client.RedeemCode.Create().
SetCode("INVITE456").
SetType(service.RedeemTypeInvitation).
SetStatus(service.StatusUnused).
SetValue(0).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("email-oauth-aff-session-token").
SetIntent(oauthIntentLogin).
SetProviderType("google").
SetProviderKey("google").
SetProviderSubject("google-aff-user").
SetResolvedEmail("pending-aff@example.com").
SetRedirectTo("/dashboard").
SetBrowserSessionKey("browser-aff-key").
SetUpstreamIdentityClaims(map[string]any{
"email": "pending-aff@example.com",
"email_verified": true,
"username": "pending-aff",
"provider": "google",
"provider_key": "google",
"provider_subject": "google-aff-user",
"aff_code": "AFF456",
}).
SetLocalFlowState(map[string]any{
"step": oauthPendingChoiceStep,
"error": "invitation_required",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
c.Request = req
handler.completeEmailOAuthRegistration(c, "google")
require.Equal(t, http.StatusOK, recorder.Code)
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
require.NoError(t, err)
require.NotEmpty(t, user.PasswordHash)
require.NotEqual(t, "secret-123", user.PasswordHash)
tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, tamperedCount)
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedInvitation.UsedBy)
require.Equal(t, user.ID, *storedInvitation.UsedBy)
}
func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("email-oauth-password-session-token").
SetIntent(oauthIntentLogin).
SetProviderType("github").
SetProviderKey("github").
SetProviderSubject("github-password-user").
SetResolvedEmail("password-required@example.com").
SetRedirectTo("/dashboard").
SetBrowserSessionKey("browser-password-key").
SetUpstreamIdentityClaims(map[string]any{
"email": "password-required@example.com",
"email_verified": true,
"username": "password-required",
"provider": "github",
"provider_key": "github",
"provider_subject": "github-password-user",
}).
SetLocalFlowState(map[string]any{
"step": oauthPendingChoiceStep,
"error": "registration_completion_required",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")})
c.Request = req
handler.completeEmailOAuthRegistration(c, "github")
require.Equal(t, http.StatusBadRequest, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
}
func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) {
emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "missing scope", http.StatusForbidden)
}))
t.Cleanup(emailServer.Close)
profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{
EmailsURL: emailServer.URL,
}, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`)
require.Error(t, err)
require.Nil(t, profile)
require.Contains(t, err.Error(), "github emails endpoint status 403")
}
type oauthEmailAffiliateBindCall struct {
userID int64
inviterID int64
}
type oauthEmailAffiliateRepoStub struct {
codeOwners map[string]int64
ensureUserIDs []int64
bindCalls []oauthEmailAffiliateBindCall
}
func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub {
return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners}
}
func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) {
r.ensureUserIDs = append(r.ensureUserIDs, userID)
return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil
}
func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) {
userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))]
if !ok {
return nil, service.ErrAffiliateProfileNotFound
}
return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil
}
func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) {
r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID})
return true, nil
}
func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) {
panic("unexpected AccrueQuota call")
}
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
panic("unexpected GetAccruedRebateFromInvitee call")
}
func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) {
panic("unexpected ThawFrozenQuota call")
}
func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) {
panic("unexpected TransferQuotaToBalance call")
}
func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) {
panic("unexpected ListInvitees call")
}
func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error {
panic("unexpected UpdateUserAffCode call")
}
func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) {
panic("unexpected ResetUserAffCode call")
}
func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error {
panic("unexpected SetUserRebateRate call")
}
func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error {
panic("unexpected BatchSetUserRebateRate call")
}
func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
panic("unexpected ListUsersWithCustomSettings call")
}
func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
panic("unexpected ListAffiliateInviteRecords call")
}
func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
panic("unexpected ListAffiliateRebateRecords call")
}
func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
panic("unexpected ListAffiliateTransferRecords call")
}
func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) {
panic("unexpected GetAffiliateUserOverview call")
}
func findSetCookieValue(cookies []*http.Cookie, name string) string {
for _, cookie := range cookies {
if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 {
return cookie.Value
}
}
return ""
}

View File

@ -2121,6 +2121,8 @@ type oauthPendingFlowTestHandlerOptions struct {
emailCache service.EmailCache
settingValues map[string]string
defaultSubAssigner service.DefaultSubscriptionAssigner
affiliateService *service.AffiliateService
affiliateFactory func(*dbent.Client, *service.SettingService) *service.AffiliateService
totpCache service.TotpCache
totpEncryptor service.SecretEncryptor
userRepoOptions oauthPendingFlowUserRepoOptions
@ -2160,6 +2162,21 @@ CREATE TABLE IF NOT EXISTS user_avatars (
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_affiliates (
user_id INTEGER PRIMARY KEY,
aff_code TEXT NOT NULL UNIQUE,
aff_code_custom BOOLEAN NOT NULL DEFAULT false,
aff_rebate_rate_percent REAL NULL,
inviter_id INTEGER NULL,
aff_count INTEGER NOT NULL DEFAULT 0,
aff_quota REAL NOT NULL DEFAULT 0,
aff_frozen_quota REAL NOT NULL DEFAULT 0,
aff_history_quota REAL NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@ -2177,14 +2194,19 @@ CREATE TABLE IF NOT EXISTS user_avatars (
},
}
settingValues := map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
}
for key, value := range options.settingValues {
settingValues[key] = value
}
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
affiliateService := options.affiliateService
if affiliateService == nil && options.affiliateFactory != nil {
affiliateService = options.affiliateFactory(client, settingSvc)
}
userRepo := &oauthPendingFlowUserRepo{
client: client,
options: options.userRepoOptions,
@ -2210,7 +2232,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
nil,
nil,
options.defaultSubAssigner,
nil,
affiliateService,
)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
var totpSvc *service.TotpService
@ -2798,6 +2820,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
panic("unexpected UpdateConcurrency call")
}
func (r *oauthPendingFlowUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
panic("unexpected BatchSetConcurrency call")
}
func (r *oauthPendingFlowUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
panic("unexpected BatchAddConcurrency call")
}
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}

View File

@ -0,0 +1,130 @@
package handler
import (
"context"
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
if h == nil || h.contentModerationService == nil {
return nil
}
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
}
func contentModerationStatus(decision *service.ContentModerationDecision) int {
if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 {
return http.StatusForbidden
}
return decision.StatusCode
}
func contentModerationErrorCode(decision *service.ContentModerationDecision) string {
return "content_policy_violation"
}
func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
if h == nil || h.contentModerationService == nil {
return nil
}
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
}
func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
if svc == nil || c == nil || c.Request == nil {
return nil
}
input := buildContentModerationInput(c, apiKey, subject, protocol, model, body)
if reqLog != nil {
reqLog.Info("content_moderation.gateway_check_start",
zap.String("request_id", input.RequestID),
zap.Int64("user_id", input.UserID),
zap.Int64("api_key_id", input.APIKeyID),
zap.String("api_key_name", input.APIKeyName),
zap.Int64p("group_id", input.GroupID),
zap.String("group_name", input.GroupName),
zap.String("endpoint", input.Endpoint),
zap.String("provider", input.Provider),
zap.String("protocol", input.Protocol),
zap.String("model", input.Model),
zap.Int("body_bytes", len(body)),
)
}
decision, err := svc.Check(c.Request.Context(), input)
if err != nil {
if reqLog != nil {
reqLog.Warn("content_moderation.check_failed", zap.Error(err))
}
return nil
}
if reqLog != nil && decision != nil {
reqLog.Info("content_moderation.gateway_check_done",
zap.String("request_id", input.RequestID),
zap.Bool("allowed", decision.Allowed),
zap.Bool("blocked", decision.Blocked),
zap.Bool("flagged", decision.Flagged),
zap.String("action", decision.Action),
zap.Int("status_code", decision.StatusCode),
zap.String("highest_category", decision.HighestCategory),
zap.Float64("highest_score", decision.HighestScore),
)
}
return decision
}
func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput {
input := service.ContentModerationCheckInput{
RequestID: contentModerationRequestID(c.Request.Context()),
UserID: subject.UserID,
Endpoint: GetInboundEndpoint(c),
Provider: contentModerationProvider(apiKey),
Model: strings.TrimSpace(model),
Protocol: protocol,
Body: body,
}
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
input.Provider = strings.TrimSpace(forcedPlatform)
}
if apiKey != nil {
input.APIKeyID = apiKey.ID
input.APIKeyName = apiKey.Name
if apiKey.User != nil {
input.UserEmail = apiKey.User.Email
}
if apiKey.GroupID != nil {
groupID := *apiKey.GroupID
input.GroupID = &groupID
}
if apiKey.Group != nil {
input.GroupName = apiKey.Group.Name
}
}
if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil {
input.Endpoint = c.Request.URL.Path
}
return input
}
func contentModerationProvider(apiKey *service.APIKey) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.Platform)
}
func contentModerationRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok {
return strings.TrimSpace(requestID)
}
return ""
}

View File

@ -11,6 +11,7 @@ type CustomMenuItem struct {
Label string `json:"label"`
IconSVG string `json:"icon_svg"`
URL string `json:"url"`
PageSlug string `json:"page_slug,omitempty"`
Visibility string `json:"visibility"` // "user" or "admin"
SortOrder int `json:"sort_order"`
}
@ -24,15 +25,19 @@ type CustomEndpoint struct {
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
LoginAgreementMode string `json:"login_agreement_mode"`
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@ -91,6 +96,17 @@ type SystemSettings struct {
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GitHubOAuthClientID string `json:"github_oauth_client_id"`
GitHubOAuthClientSecretConfigured bool `json:"github_oauth_client_secret_configured"`
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
GoogleOAuthClientID string `json:"google_oauth_client_id"`
GoogleOAuthClientSecretConfigured bool `json:"google_oauth_client_secret_configured"`
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
@ -197,6 +213,9 @@ type SystemSettings struct {
// Available Channels feature switch (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
// 风控中心功能开关
RiskControlEnabled bool `json:"risk_control_enabled"`
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
@ -210,45 +229,52 @@ type DefaultSubscriptionSetting struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
LoginAgreementMode string `json:"login_agreement_mode"`
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
LoginAgreementRevision string `json:"login_agreement_revision"`
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
@ -256,6 +282,14 @@ type PublicSettings struct {
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
RiskControlEnabled bool `json:"risk_control_enabled"`
}
type LoginAgreementDocument struct {
ID string `json:"id"`
Title string `json:"title"`
ContentMD string `json:"content_md"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO

View File

@ -46,6 +46,7 @@ type GatewayHandler struct {
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
contentModerationService *service.ContentModerationService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
requestEventBus *service.RequestEventBus
@ -68,6 +69,7 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
contentModerationService *service.ContentModerationService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
@ -103,6 +105,7 @@ func NewGatewayHandler(
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
contentModerationService: contentModerationService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
requestEventBus: requestEventBus,
@ -215,6 +218,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// Track if we've started streaming (for error handling)
streamStarted := false

View File

@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)

View File

@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)

View File

@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
googleError(c, contentModerationStatus(decision), decision.Message)
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名

View File

@ -33,6 +33,7 @@ type AdminHandlers struct {
Channel *admin.ChannelHandler
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
ContentModeration *admin.ContentModerationHandler
Payment *admin.PaymentHandler
Windsurf *admin.WindsurfHandler
Affiliate *admin.AffiliateHandler

View File

@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)

View File

@ -27,15 +27,16 @@ import (
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int
cfg *config.Config
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
contentModerationService *service.ContentModerationService
concurrencyHelper *ConcurrencyHelper
imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int
cfg *config.Config
}
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
contentModerationService *service.ContentModerationService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler(
}
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
contentModerationService: contentModerationService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
}
@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message)
return
}
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
return
@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
_ = conn.CloseNow()
}
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
if conn == nil || decision == nil {
return
}
if ctx == nil {
ctx = context.Background()
}
message := strings.TrimSpace(decision.Message)
if message == "" {
message = "content moderation blocked this request"
}
payload, err := json.Marshal(gin.H{
"event_id": "evt_content_moderation_blocked",
"type": "error",
"error": gin.H{
"type": "invalid_request_error",
"code": contentModerationErrorCode(decision),
"message": message,
},
})
if err != nil {
payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`)
}
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
_ = conn.Write(writeCtx, coderws.MessageText, payload)
}
func summarizeWSCloseErrorForLog(err error) (string, string) {
if err == nil {
return "-", "-"

View File

@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
type contentModerationHandlerSettingRepo struct {
values map[string]string
}
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
if value, ok := r.values[key]; ok {
return &service.Setting{Key: key, Value: value}, nil
}
return nil, service.ErrSettingNotFound
}
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
if value, ok := r.values[key]; ok {
return value, nil
}
return "", service.ErrSettingNotFound
}
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
if r.values == nil {
r.values = map[string]string{}
}
r.values[key] = value
return nil
}
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := map[string]string{}
for _, key := range keys {
if value, ok := r.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
if r.values == nil {
r.values = map[string]string{}
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.values))
for key, value := range r.values {
out[key] = value
}
return out, nil
}
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.values, key)
return nil
}
type contentModerationHandlerTestRepo struct {
logs []service.ContentModerationLog
}
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
if log != nil {
r.logs = append(r.logs, *log)
}
return nil
}
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
return 0, nil
}
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
return &service.ContentModerationCleanupResult{}, nil
}
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
gin.SetMode(gin.TestMode)
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
}))
defer moderationServer.Close()
cfg := &service.ContentModerationConfig{
Enabled: true,
Mode: service.ContentModerationModePreBlock,
BaseURL: moderationServer.URL,
Model: "omni-moderation-latest",
APIKeys: []string{"sk-test"},
SampleRate: 100,
AllGroups: true,
BlockMessage: "内容审计测试阻断",
}
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationHandlerTestRepo{}
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
service.SettingKeyRiskControlEnabled: "true",
service.SettingKeyContentModerationConfig: string(rawCfg),
}}
moderationSvc := service.NewContentModerationService(
settingRepo,
repo,
nil,
nil,
nil,
nil,
nil,
)
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
UserID: 1,
Endpoint: "/v1/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: service.ContentModerationProtocolOpenAIResponses,
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
repo.logs = nil
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
contentModerationService: moderationSvc,
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
}
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
"type":"response.create",
"model":"gpt-5.5",
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, payload, readErr := clientConn.Read(readCtx)
cancelRead()
if readErr == nil {
require.Contains(t, string(payload), "content_policy_violation")
require.Contains(t, string(payload), "内容审计测试阻断")
} else {
var closeErr coderws.CloseError
require.ErrorAs(t, readErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
}
require.Len(t, repo.logs, 1)
require.True(t, repo.logs[0].Flagged)
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,

View File

@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
return
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked {
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
return
}
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
if !acquired {
return

View File

@ -0,0 +1,283 @@
package handler
import (
"encoding/json"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var validSlugPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
const maxPageFileSize = 1 << 20 // 1MB
type PageHandler struct {
pagesDir string
settingService *service.SettingService
}
func NewPageHandler(dataDir string, settingService *service.SettingService) *PageHandler {
pagesDir := filepath.Join(dataDir, "pages")
_ = os.MkdirAll(pagesDir, 0755)
return &PageHandler{pagesDir: pagesDir, settingService: settingService}
}
// GetPageContent serves raw markdown content for a given slug.
// GET /api/v1/pages/:slug
func (h *PageHandler) GetPageContent(c *gin.Context) {
slug := c.Param("slug")
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
response.BadRequest(c, "Invalid page slug")
return
}
// Visibility check: slug must be configured in custom_menu_items
// and the user must have permission based on visibility setting
if !h.checkSlugVisibility(c, slug) {
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
return
}
filePath := filepath.Join(h.pagesDir, slug+".md")
cleaned := filepath.Clean(filePath)
if !strings.HasPrefix(cleaned, filepath.Clean(h.pagesDir)) {
response.BadRequest(c, "Invalid page slug")
return
}
info, err := os.Stat(cleaned)
if err != nil || info.IsDir() {
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
return
}
if info.Size() > maxPageFileSize {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "page too large"})
return
}
content, err := os.ReadFile(cleaned)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read page"})
return
}
c.Data(http.StatusOK, "text/markdown; charset=utf-8", content)
}
// ListPages returns available page slugs.
// GET /api/v1/pages
func (h *PageHandler) ListPages(c *gin.Context) {
entries, err := os.ReadDir(h.pagesDir)
if err != nil {
response.Success(c, []string{})
return
}
slugs := make([]string, 0, len(entries))
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if strings.HasSuffix(name, ".md") {
slugs = append(slugs, strings.TrimSuffix(name, ".md"))
}
}
response.Success(c, slugs)
}
// ServePageImage serves images from data/pages/{slug}/ directory.
// GET /api/v1/pages/:slug/images/*filename
// No JWT required (browser img tags can't carry tokens), but visibility is checked.
func (h *PageHandler) ServePageImage(c *gin.Context) {
slug := c.Param("slug")
filename := c.Param("filename")
filename = strings.TrimPrefix(filename, "/")
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
c.Status(http.StatusNotFound)
return
}
if !h.checkImageSlugVisibility(c, slug) {
c.Status(http.StatusNotFound)
return
}
imagesDir := filepath.Join(h.pagesDir, slug)
cleaned, ok := resolvePageImagePath(h.pagesDir, imagesDir, filename)
if !ok {
c.Status(http.StatusNotFound)
return
}
info, err := os.Stat(cleaned)
if err != nil || info.IsDir() {
c.Status(http.StatusNotFound)
return
}
c.File(cleaned)
}
func resolvePageImagePath(pagesDir, imagesDir, filename string) (string, bool) {
relPath, ok := cleanPageImageRelativePath(filename)
if !ok {
return "", false
}
cleanedPagesDir := filepath.Clean(pagesDir)
cleanedImagesDir := filepath.Clean(imagesDir)
cleanedTarget := filepath.Clean(filepath.Join(cleanedImagesDir, relPath))
if !isPathWithinBase(cleanedTarget, cleanedImagesDir) {
return "", false
}
realPagesDir, err := filepath.EvalSymlinks(cleanedPagesDir)
if err != nil {
return "", false
}
realImagesDir, err := filepath.EvalSymlinks(cleanedImagesDir)
if err != nil || !isPathWithinBase(realImagesDir, realPagesDir) {
return "", false
}
realTarget, err := filepath.EvalSymlinks(cleanedTarget)
if err != nil || !isPathWithinBase(realTarget, realImagesDir) {
return "", false
}
return realTarget, true
}
func cleanPageImageRelativePath(filename string) (string, bool) {
if filename == "" {
return "", false
}
if strings.HasPrefix(filename, "/") {
return "", false
}
decoded, err := url.PathUnescape(filename)
if err != nil {
return "", false
}
if decoded == "" || strings.HasPrefix(decoded, "/") || strings.Contains(decoded, "\\") || strings.ContainsRune(decoded, 0) {
return "", false
}
parts := make([]string, 0)
for _, part := range strings.Split(decoded, "/") {
switch part {
case "", ".":
continue
case "..":
return "", false
default:
parts = append(parts, part)
}
}
if len(parts) == 0 {
return "", false
}
relPath := filepath.Join(parts...)
if filepath.IsAbs(relPath) || filepath.VolumeName(relPath) != "" {
return "", false
}
return relPath, true
}
func isPathWithinBase(path, base string) bool {
rel, err := filepath.Rel(filepath.Clean(base), filepath.Clean(path))
if err != nil {
return false
}
return rel != "." && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// findSlugVisibility looks up the slug in custom_menu_items and returns (visibility, found).
func (h *PageHandler) findSlugVisibility(c *gin.Context, slug string) (string, bool) {
if h.settingService == nil {
return "", false
}
raw := h.settingService.GetCustomMenuItemsRaw(c.Request.Context())
if raw == "" || raw == "[]" {
return "", false
}
var items []struct {
URL string `json:"url"`
PageSlug string `json:"page_slug"`
Visibility string `json:"visibility"`
}
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return "", false
}
for _, item := range items {
itemSlug := item.PageSlug
if itemSlug == "" && strings.HasPrefix(item.URL, "md:") {
itemSlug = strings.TrimPrefix(item.URL, "md:")
}
if itemSlug == slug {
return item.Visibility, true
}
}
return "", false
}
// checkSlugVisibility verifies the slug is configured in custom_menu_items
// and the authenticated user has permission to view it.
func (h *PageHandler) checkSlugVisibility(c *gin.Context, slug string) bool {
visibility, found := h.findSlugVisibility(c, slug)
if !found {
return false
}
if visibility == "admin" {
role, _ := middleware2.GetUserRoleFromContext(c)
return role == "admin"
}
return true
}
// checkImageSlugVisibility checks visibility for image requests (no JWT available).
// Only allows user-visible pages; admin-only pages are blocked.
func (h *PageHandler) checkImageSlugVisibility(c *gin.Context, slug string) bool {
visibility, found := h.findSlugVisibility(c, slug)
if !found {
return false
}
return visibility != "admin"
}
// RegisterPageRoutes registers page routes on a router group.
func RegisterPageRoutes(v1 *gin.RouterGroup, dataDir string, jwtAuth gin.HandlerFunc, adminAuth gin.HandlerFunc, settingService *service.SettingService) {
h := NewPageHandler(dataDir, settingService)
// Authenticated page content (JWT required + visibility check)
pages := v1.Group("/pages")
pages.Use(jwtAuth)
{
pages.GET("/:slug", h.GetPageContent)
}
// Images: no JWT (browser img tags can't carry tokens), visibility check in handler
pageImages := v1.Group("/pages")
{
pageImages.GET("/:slug/images/*filename", h.ServePageImage)
}
// Admin-only: list all available pages
adminPages := v1.Group("/pages")
adminPages.Use(adminAuth)
{
adminPages.GET("", h.ListPages)
}
}

View File

@ -0,0 +1,102 @@
package handler
import (
"os"
"path/filepath"
"testing"
)
func TestCleanPageImageRelativePath(t *testing.T) {
tests := []struct {
name string
in string
want string
ok bool
}{
{name: "single filename", in: "logo.png", want: "logo.png", ok: true},
{name: "nested path", in: "images/logo.png", want: filepath.Join("images", "logo.png"), ok: true},
{name: "dot prefix", in: "./logo.png", want: "logo.png", ok: true},
{name: "url escaped slash", in: "images%2Flogo.png", want: filepath.Join("images", "logo.png"), ok: true},
{name: "parent traversal", in: "../secret.png", ok: false},
{name: "encoded parent traversal", in: "%2e%2e/secret.png", ok: false},
{name: "backslash traversal", in: `images\secret.png`, ok: false},
{name: "absolute path", in: "/etc/passwd", ok: false},
{name: "encoded absolute path", in: "%2fetc/passwd", ok: false},
{name: "encoded nul byte", in: "logo.png%00", ok: false},
{name: "invalid escape", in: "logo.png%zz", ok: false},
{name: "empty path", in: "", ok: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := cleanPageImageRelativePath(tt.in)
if ok != tt.ok {
t.Fatalf("ok = %v, want %v", ok, tt.ok)
}
if got != tt.want {
t.Fatalf("path = %q, want %q", got, tt.want)
}
})
}
}
func TestResolvePageImagePath(t *testing.T) {
root := t.TempDir()
pagesDir := filepath.Join(root, "pages")
base := filepath.Join(pagesDir, "guide")
if err := os.MkdirAll(filepath.Join(base, "images"), 0755); err != nil {
t.Fatalf("create images dir: %v", err)
}
if err := os.WriteFile(filepath.Join(base, "logo.png"), []byte("fake"), 0644); err != nil {
t.Fatalf("create direct image: %v", err)
}
if err := os.WriteFile(filepath.Join(base, "images", "logo.png"), []byte("fake"), 0644); err != nil {
t.Fatalf("create image: %v", err)
}
got, ok := resolvePageImagePath(pagesDir, base, "logo.png")
if !ok {
t.Fatal("expected direct image path to be accepted")
}
want := filepath.Join(base, "logo.png")
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
got, ok = resolvePageImagePath(pagesDir, base, "images/logo.png")
if !ok {
t.Fatal("expected nested image path to be accepted")
}
want = filepath.Join(base, "images", "logo.png")
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
if got, ok := resolvePageImagePath(pagesDir, base, "../guide.md"); ok {
t.Fatalf("expected traversal to be rejected, got %q", got)
}
}
func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
root := t.TempDir()
pagesDir := filepath.Join(root, "pages")
base := filepath.Join(pagesDir, "guide")
outside := filepath.Join(root, "outside")
if err := os.MkdirAll(base, 0755); err != nil {
t.Fatalf("create page dir: %v", err)
}
if err := os.MkdirAll(outside, 0755); err != nil {
t.Fatalf("create outside dir: %v", err)
}
if err := os.WriteFile(filepath.Join(outside, "secret.png"), []byte("secret"), 0644); err != nil {
t.Fatalf("create outside file: %v", err)
}
if err := os.Symlink(outside, filepath.Join(base, "images")); err != nil {
t.Skipf("symlink not supported: %v", err)
}
if got, ok := resolvePageImagePath(pagesDir, base, "images/secret.png"); ok {
t.Fatalf("expected symlink escape to be rejected, got %q", got)
}
}

View File

@ -40,6 +40,11 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
LoginAgreementEnabled: settings.LoginAgreementEnabled,
LoginAgreementMode: settings.LoginAgreementMode,
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
LoginAgreementRevision: settings.LoginAgreementRevision,
LoginAgreementDocuments: publicLoginAgreementDocumentsToDTO(settings.LoginAgreementDocuments),
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
@ -63,6 +68,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: h.version,
@ -77,5 +84,19 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
RiskControlEnabled: settings.RiskControlEnabled,
})
}
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
result := make([]dto.LoginAgreementDocument, 0, len(items))
for _, item := range items {
result = append(result, dto.LoginAgreementDocument{
ID: item.ID,
Title: item.Title,
ContentMD: item.ContentMD,
})
}
return result
}

View File

@ -87,6 +87,8 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil

View File

@ -36,6 +36,7 @@ func ProvideAdminHandlers(
channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
contentModerationHandler *admin.ContentModerationHandler,
paymentHandler *admin.PaymentHandler,
windsurfHandler *admin.WindsurfHandler,
affiliateHandler *admin.AffiliateHandler,
@ -68,6 +69,7 @@ func ProvideAdminHandlers(
Channel: channelHandler,
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
ContentModeration: contentModerationHandler,
Payment: paymentHandler,
Windsurf: windsurfHandler,
Affiliate: affiliateHandler,
@ -180,6 +182,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelHandler,
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
admin.NewContentModerationHandler,
admin.NewPaymentHandler,
admin.NewAffiliateHandler,

View File

@ -317,6 +317,7 @@ const CLICurrentVersion = "2.1.92"
// - OAuth 账号 + 非 haiku追加这整份列表再按需保留 client 带来的 beta。
// - OAuth 账号 + haikuAnthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
// - 不默认加入 redact-thinking避免上游抹除 thinking 内容;客户端显式传入时由合并逻辑保留。
func FullClaudeCodeMimicryBetas() []string {
return []string{
BetaClaudeCode,
@ -324,7 +325,6 @@ func FullClaudeCodeMimicryBetas() []string {
BetaInterleavedThinking,
BetaPromptCachingScope,
BetaEffort,
BetaRedactThinking,
BetaContextManagement,
BetaExtendedCacheTTL,
}

View File

@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldName,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,

View File

@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S
got, err := repo.GetByKeyForAuth(ctx, key.Key)
require.NoError(t, err)
require.Equal(t, key.Name, got.Name)
require.NotNil(t, got.Group)
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
}

View File

@ -0,0 +1,71 @@
package repository
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes"
type contentModerationHashCache struct {
rdb *redis.Client
}
func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache {
return &contentModerationHashCache{rdb: rdb}
}
func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
inputHash = strings.TrimSpace(inputHash)
if c == nil || c.rdb == nil || inputHash == "" {
return nil
}
return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err()
}
func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
inputHash = strings.TrimSpace(inputHash)
if c == nil || c.rdb == nil || inputHash == "" {
return false, nil
}
return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
}
func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
inputHash = strings.TrimSpace(inputHash)
if c == nil || c.rdb == nil || inputHash == "" {
return false, nil
}
deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
if err != nil {
return false, err
}
return deleted > 0, nil
}
func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
if c == nil || c.rdb == nil {
return 0, nil
}
deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
if err != nil {
return 0, err
}
if deleted == 0 {
return 0, nil
}
if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil {
return 0, err
}
return deleted, nil
}
func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
if c == nil || c.rdb == nil {
return 0, nil
}
return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
}

View File

@ -0,0 +1,274 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type contentModerationRepository struct {
db *sql.DB
}
func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository {
return &contentModerationRepository{db: db}
}
func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
if log == nil {
return nil
}
categoryScores, err := json.Marshal(log.CategoryScores)
if err != nil {
return fmt.Errorf("marshal moderation category scores: %w", err)
}
thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot)
if err != nil {
return fmt.Errorf("marshal moderation thresholds: %w", err)
}
var userID any
if log.UserID != nil {
userID = *log.UserID
}
var apiKeyID any
if log.APIKeyID != nil {
apiKeyID = *log.APIKeyID
}
var groupID any
if log.GroupID != nil {
groupID = *log.GroupID
}
var latency any
if log.UpstreamLatencyMS != nil {
latency = *log.UpstreamLatencyMS
}
err = r.db.QueryRowContext(ctx, `
INSERT INTO content_moderation_logs (
request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name,
endpoint, provider, model, mode, action, flagged, highest_category, highest_score,
category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error,
violation_count, auto_banned, email_sent, queue_delay_ms
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9, $10, $11, $12, $13, $14, $15,
$16::jsonb, $17::jsonb, $18, $19, $20,
$21, $22, $23, $24
) RETURNING id, created_at`,
log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName,
log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore,
string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error,
log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS),
).Scan(&log.ID, &log.CreatedAt)
if err != nil {
return fmt.Errorf("insert content moderation log: %w", err)
}
return nil
}
func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
where, args := buildContentModerationLogWhere(filter)
whereSQL := "WHERE " + strings.Join(where, " AND ")
var total int64
if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil {
return nil, nil, fmt.Errorf("count content moderation logs: %w", err)
}
params := filter.Pagination
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
if params.PageSize > 100 {
params.PageSize = 100
}
queryArgs := append([]any{}, args...)
queryArgs = append(queryArgs, params.Limit(), params.Offset())
rows, err := r.db.QueryContext(ctx, `
SELECT
l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name,
l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score,
l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error,
l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at
FROM content_moderation_logs l
LEFT JOIN users u ON u.id = l.user_id `+whereSQL+`
ORDER BY l.created_at DESC, l.id DESC
LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)),
queryArgs...,
)
if err != nil {
return nil, nil, fmt.Errorf("list content moderation logs: %w", err)
}
defer func() { _ = rows.Close() }()
items := make([]service.ContentModerationLog, 0)
for rows.Next() {
var item service.ContentModerationLog
var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64
var scoresRaw, thresholdsRaw []byte
if err := rows.Scan(
&item.ID,
&item.RequestID,
&userID,
&item.UserEmail,
&apiKeyID,
&item.APIKeyName,
&groupID,
&item.GroupName,
&item.Endpoint,
&item.Provider,
&item.Model,
&item.Mode,
&item.Action,
&item.Flagged,
&item.HighestCategory,
&item.HighestScore,
&scoresRaw,
&thresholdsRaw,
&item.InputExcerpt,
&latency,
&item.Error,
&item.ViolationCount,
&item.AutoBanned,
&item.EmailSent,
&item.UserStatus,
&queueDelay,
&item.CreatedAt,
); err != nil {
return nil, nil, fmt.Errorf("scan content moderation log: %w", err)
}
if userID.Valid {
v := userID.Int64
item.UserID = &v
}
if apiKeyID.Valid {
v := apiKeyID.Int64
item.APIKeyID = &v
}
if groupID.Valid {
v := groupID.Int64
item.GroupID = &v
}
if latency.Valid {
v := int(latency.Int64)
item.UpstreamLatencyMS = &v
}
if queueDelay.Valid {
v := int(queueDelay.Int64)
item.QueueDelayMS = &v
}
item.CategoryScores = map[string]float64{}
_ = json.Unmarshal(scoresRaw, &item.CategoryScores)
item.ThresholdSnapshot = map[string]float64{}
_ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot)
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err)
}
return items, paginationResultFromTotal(total, params), nil
}
func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
if userID <= 0 {
return 0, nil
}
var count int
err := r.db.QueryRowContext(ctx, `
WITH last_auto_ban AS (
SELECT MAX(created_at) AS at
FROM content_moderation_logs
WHERE user_id = $1 AND auto_banned = TRUE
)
SELECT COUNT(*)
FROM content_moderation_logs
WHERE user_id = $1
AND flagged = TRUE
AND created_at >= $2
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
`, userID, since).Scan(&count)
if err != nil {
return 0, fmt.Errorf("count user content moderation flagged logs: %w", err)
}
return count, nil
}
func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()}
if r == nil || r.db == nil {
return result, nil
}
hitExec, err := r.db.ExecContext(ctx, `
DELETE FROM content_moderation_logs
WHERE flagged = TRUE AND created_at < $1
`, hitBefore)
if err != nil {
return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err)
}
result.DeletedHit, _ = hitExec.RowsAffected()
nonHitExec, err := r.db.ExecContext(ctx, `
DELETE FROM content_moderation_logs
WHERE flagged = FALSE AND created_at < $1
`, nonHitBefore)
if err != nil {
return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err)
}
result.DeletedNonHit, _ = nonHitExec.RowsAffected()
result.FinishedAt = time.Now()
return result, nil
}
func nullableIntPtr(value *int) any {
if value == nil {
return nil
}
return *value
}
func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) {
where := []string{"l.id IS NOT NULL"}
args := make([]any, 0)
add := func(expr string, value any) {
args = append(args, value)
where = append(where, fmt.Sprintf(expr, len(args)))
}
switch strings.ToLower(strings.TrimSpace(filter.Result)) {
case "hit", "flagged":
where = append(where, "l.flagged = TRUE")
case "blocked", "block":
where = append(where, "l.action = 'block'")
case "pass", "allow":
where = append(where, "l.flagged = FALSE AND l.error = ''")
case "error":
where = append(where, "l.error <> ''")
}
if filter.GroupID != nil {
add("l.group_id = $%d", *filter.GroupID)
}
if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" {
add("l.endpoint = $%d", endpoint)
}
if search := strings.TrimSpace(filter.Search); search != "" {
like := "%" + search + "%"
args = append(args, like, like, like, like, like)
idx := len(args) - 4
where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4))
}
if filter.From != nil && !filter.From.IsZero() {
add("l.created_at >= $%d", *filter.From)
}
if filter.To != nil && !filter.To.IsZero() {
add("l.created_at <= $%d", *filter.To)
}
return where, args
}

View File

@ -737,6 +737,37 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
if len(userIDs) == 0 {
return 0, nil
}
if value < 0 {
value = 0
}
res, err := r.sql.ExecContext(ctx,
"UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
value, pq.Array(userIDs))
if err != nil {
return 0, fmt.Errorf("batch set concurrency: %w", err)
}
affected, _ := res.RowsAffected()
return int(affected), nil
}
func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
if len(userIDs) == 0 {
return 0, nil
}
res, err := r.sql.ExecContext(ctx,
"UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
delta, pq.Array(userIDs))
if err != nil {
return 0, fmt.Errorf("batch add concurrency: %w", err)
}
affected, _ := res.RowsAffected()
return int(affected), nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
}

View File

@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewContentModerationRepository,
NewAffiliateRepository,
// Cache implementations
@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet(
NewRefreshTokenCache,
NewErrorPassthroughCache,
NewTLSFingerprintProfileCache,
NewContentModerationHashCache,
// Encryptors
NewAESEncryptor,

View File

@ -646,12 +646,21 @@ func TestAPIContracts(t *testing.T) {
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"frontend_url": "",
"totp_enabled": false,
"totp_encryption_key_configured": false,
"login_agreement_enabled": false,
"login_agreement_mode": "modal",
"login_agreement_updated_at": "2026-03-31",
"login_agreement_documents": [
{"id": "terms", "title": "服务条款", "content_md": ""},
{"id": "usage-policy", "title": "使用政策", "content_md": ""},
{"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""},
{"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""}
],
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
@ -685,6 +694,16 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"github_oauth_enabled": false,
"github_oauth_client_id": "",
"github_oauth_client_secret_configured": false,
"github_oauth_redirect_url": "",
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
"google_oauth_enabled": false,
"google_oauth_client_id": "",
"google_oauth_client_secret_configured": false,
"google_oauth_redirect_url": "",
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
@ -700,6 +719,16 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_github_balance": 0,
"auth_source_default_github_concurrency": 5,
"auth_source_default_github_subscriptions": [],
"auth_source_default_github_grant_on_signup": false,
"auth_source_default_github_grant_on_first_bind": false,
"auth_source_default_google_balance": 0,
"auth_source_default_google_concurrency": 5,
"auth_source_default_google_subscriptions": [],
"auth_source_default_google_grant_on_signup": false,
"auth_source_default_google_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
@ -792,6 +821,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"risk_control_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
@ -859,12 +889,21 @@ func TestAPIContracts(t *testing.T) {
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"invitation_code_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "",
"smtp_port": 587,
"smtp_username": "",
"invitation_code_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"login_agreement_enabled": false,
"login_agreement_mode": "modal",
"login_agreement_updated_at": "2026-03-31",
"login_agreement_documents": [
{"id": "terms", "title": "服务条款", "content_md": ""},
{"id": "usage-policy", "title": "使用政策", "content_md": ""},
{"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""},
{"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""}
],
"smtp_host": "",
"smtp_port": 587,
"smtp_username": "",
"smtp_password_configured": false,
"smtp_from_email": "",
"smtp_from_name": "",
@ -898,6 +937,16 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"github_oauth_enabled": false,
"github_oauth_client_id": "",
"github_oauth_client_secret_configured": false,
"github_oauth_redirect_url": "",
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
"google_oauth_enabled": false,
"google_oauth_client_id": "",
"google_oauth_client_secret_configured": false,
"google_oauth_redirect_url": "",
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subscription to API Conversion Platform",
@ -983,6 +1032,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"risk_control_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
@ -1005,6 +1055,16 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_github_balance": 0,
"auth_source_default_github_concurrency": 5,
"auth_source_default_github_subscriptions": [],
"auth_source_default_github_grant_on_signup": false,
"auth_source_default_github_grant_on_first_bind": false,
"auth_source_default_google_balance": 0,
"auth_source_default_google_concurrency": 5,
"auth_source_default_google_subscriptions": [],
"auth_source_default_google_grant_on_signup": false,
"auth_source_default_google_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
@ -1123,7 +1183,7 @@ func newContractDeps(t *testing.T) *contractDeps {
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil, nil)
redeemHandler := handler.NewRedeemHandler(redeemService)
settingRepo := newStubSettingRepo()
@ -1294,6 +1354,9 @@ func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
return errors.New("not implemented")
}
func (r *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (r *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return false, errors.New("not implemented")
}

View File

@ -198,6 +198,9 @@ func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
panic("unexpected UpdateConcurrency call")
}
func (s *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}

View File

@ -40,6 +40,8 @@ func backendModeAllowsAuthPath(path string) bool {
"/auth/oauth/wechat/callback",
"/auth/oauth/wechat/payment/callback",
"/auth/oauth/oidc/callback",
"/auth/oauth/github/callback",
"/auth/oauth/google/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",

View File

@ -246,6 +246,30 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/oauth/oidc/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_github_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/github/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_github_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/github/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_google_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/google/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_google_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/google/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",

View File

@ -120,4 +120,6 @@ func registerRoutes(
routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster)
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
handler.RegisterPageRoutes(v1, cfg.Pricing.DataDir, gin.HandlerFunc(jwtAuth), gin.HandlerFunc(adminAuth), settingService)
}

View File

@ -95,11 +95,28 @@ func RegisterAdminRoutes(
// 渠道监控
registerChannelMonitorRoutes(admin, h)
// 风控中心
registerContentModerationRoutes(admin, h)
// 邀请返利(专属用户管理)
registerAffiliateRoutes(admin, h)
}
}
func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
risk := admin.Group("/risk-control")
{
risk.GET("/config", h.Admin.ContentModeration.GetConfig)
risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig)
risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys)
risk.GET("/status", h.Admin.ContentModeration.GetStatus)
risk.GET("/logs", h.Admin.ContentModeration.ListLogs)
risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser)
risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash)
risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes)
}
}
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys")
{
@ -234,6 +251,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
@ -270,6 +288,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
accounts.POST("/import/codex-session", h.Admin.Account.ImportCodexSession)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)

View File

@ -63,6 +63,22 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/github/start", h.Auth.GitHubOAuthStart)
auth.GET("/oauth/github/callback", h.Auth.GitHubOAuthCallback)
auth.POST("/oauth/github/complete-registration",
rateLimiter.LimitWithOptions("oauth-github-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteGitHubOAuthRegistration,
)
auth.GET("/oauth/google/start", h.Auth.GoogleOAuthStart)
auth.GET("/oauth/google/callback", h.Auth.GoogleOAuthCallback)
auth.POST("/oauth/google/complete-registration",
rateLimiter.LimitWithOptions("oauth-google-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteGoogleOAuthRegistration,
)
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")

View File

@ -33,6 +33,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
cleaned := make([]int64, 0, len(userIDs))
for _, uid := range userIDs {
if uid > 0 {
cleaned = append(cleaned, uid)
}
}
if len(cleaned) == 0 {
return 0, nil
}
var affected int
var err error
switch mode {
case "set":
affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value)
case "add":
affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value)
default:
return 0, errors.New("invalid mode: must be 'set' or 'add'")
}
if err != nil {
return 0, err
}
if s.authCacheInvalidator != nil {
for _, uid := range cleaned {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid)
}
}
return affected, nil
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {

View File

@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
panic("unexpected")
}

View File

@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i
panic("unexpected UpdateConcurrency call")
}
func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
if s.existsErr != nil {
return false, s.existsErr

View File

@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64)
return 0, nil
}
func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {

View File

@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Name string `json:"name"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`

View File

@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
type apiKeyAuthCacheConfig struct {
l1Size int
@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Name: apiKey.Name,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Name: snapshot.Name,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,

View File

@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Name: "Audit Key",
Status: StatusActive,
User: &User{
ID: 2,
@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.Equal(t, apiKey.Name, roundTrip.Name)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}

View File

@ -0,0 +1,274 @@
package service
import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
type EmailOAuthIdentityInput struct {
ProviderType string
ProviderKey string
ProviderSubject string
Email string
EmailVerified bool
Username string
DisplayName string
AvatarURL string
UpstreamMetadata map[string]any
}
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuth(ctx context.Context, input EmailOAuthIdentityInput) (*TokenPair, *User, error) {
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, "", "")
}
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuthWithInvitation(
ctx context.Context,
input EmailOAuthIdentityInput,
invitationCode string,
affiliateCode string,
) (*TokenPair, *User, error) {
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, invitationCode, affiliateCode)
}
func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
ctx context.Context,
input EmailOAuthIdentityInput,
invitationCode string,
affiliateCode string,
) (*TokenPair, *User, error) {
if s == nil || s.userRepo == nil || s.entClient == nil {
return nil, nil, ErrServiceUnavailable
}
providerType := normalizeOAuthSignupSource(input.ProviderType)
if providerType != "github" && providerType != "google" {
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
}
providerKey := strings.TrimSpace(input.ProviderKey)
if providerKey == "" {
providerKey = providerType
}
providerSubject := strings.TrimSpace(input.ProviderSubject)
if providerSubject == "" {
return nil, nil, infraerrors.BadRequest("OAUTH_SUBJECT_MISSING", "oauth subject is missing")
}
if !input.EmailVerified {
return nil, nil, infraerrors.Forbidden("OAUTH_EMAIL_NOT_VERIFIED", "oauth email is not verified")
}
email := strings.TrimSpace(strings.ToLower(input.Email))
if email == "" || len(email) > 255 {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
identityUser, err := s.findEmailOAuthIdentityOwner(ctx, providerType, providerKey, providerSubject)
if err != nil {
return nil, nil, err
}
if identityUser != nil && !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
return nil, nil, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
}
user := identityUser
created := false
if user == nil {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
user, err = s.createEmailOAuthUser(ctx, email, input.Username, providerType, invitationCode, affiliateCode)
if err != nil {
return nil, nil, err
}
created = true
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during %s oauth login: %v", providerType, err)
return nil, nil, ErrServiceUnavailable
}
}
}
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
if err := s.ensureEmailOAuthIdentity(ctx, user.ID, EmailOAuthIdentityInput{
ProviderType: providerType,
ProviderKey: providerKey,
ProviderSubject: providerSubject,
Email: email,
EmailVerified: input.EmailVerified,
Username: input.Username,
DisplayName: input.DisplayName,
AvatarURL: input.AvatarURL,
UpstreamMetadata: input.UpstreamMetadata,
}); err != nil {
return nil, nil, err
}
if user.Username == "" && strings.TrimSpace(input.Username) != "" {
user.Username = strings.TrimSpace(input.Username)
if err := s.userRepo.Update(ctx, user); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after %s oauth login: %v", providerType, err)
}
}
if !created {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, providerType); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply %s first bind defaults: %v", providerType, err)
}
}
s.RecordSuccessfulLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, providerType, invitationCode, affiliateCode string) (*User, error) {
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, ErrRegDisabled
}
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrInvitationCodeRequired) {
return nil, ErrOAuthInvitationRequired
}
return nil, err
}
randomPassword, err := randomHexString(32)
if err != nil {
return nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
grantPlan := s.resolveSignupGrantPlan(ctx, providerType)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
user := &User{
Email: email,
Username: strings.TrimSpace(username),
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: providerType,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
existing, loadErr := s.userRepo.GetByEmail(ctx, email)
if loadErr != nil {
return nil, ErrServiceUnavailable
}
return existing, nil
}
return nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, providerType, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, invitationCode)
return nil, ErrInvitationCodeInvalid
}
}
return user, nil
}
func (s *AuthService) findEmailOAuthIdentityOwner(ctx context.Context, providerType, providerKey, providerSubject string) (*User, error) {
identity, err := s.entClient.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
user, err := s.userRepo.GetByID(ctx, identity.UserID)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return nil, nil
}
return nil, ErrServiceUnavailable
}
return user, nil
}
func (s *AuthService) ensureEmailOAuthIdentity(ctx context.Context, userID int64, input EmailOAuthIdentityInput) error {
metadata := map[string]any{
"email": strings.TrimSpace(strings.ToLower(input.Email)),
"email_verified": input.EmailVerified,
}
for key, value := range input.UpstreamMetadata {
metadata[key] = value
}
if strings.TrimSpace(input.Username) != "" {
metadata["username"] = strings.TrimSpace(input.Username)
}
if strings.TrimSpace(input.DisplayName) != "" {
metadata["display_name"] = strings.TrimSpace(input.DisplayName)
}
if strings.TrimSpace(input.AvatarURL) != "" {
metadata["avatar_url"] = strings.TrimSpace(input.AvatarURL)
}
providerType := normalizeOAuthSignupSource(input.ProviderType)
providerKey := strings.TrimSpace(input.ProviderKey)
providerSubject := strings.TrimSpace(input.ProviderSubject)
identity, err := s.entClient.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil {
if identity.UserID != userID {
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
_, err = s.entClient.AuthIdentity.UpdateOneID(identity.ID).
SetMetadata(metadata).
Save(ctx)
return err
}
_, err = s.entClient.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetMetadata(metadata).
Save(ctx)
return err
}

View File

@ -10,6 +10,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func normalizeOAuthSignupSource(signupSource string) string {
@ -17,7 +18,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
switch signupSource {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
case "linuxdo", "wechat", "oidc", "github", "google":
return signupSource
default:
return "email"
@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return tokenPair, user, nil
}
// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth
// provider that has already returned a verified email address.
func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
ctx context.Context,
email string,
password string,
invitationCode string,
signupSource string,
) (*TokenPair, *User, error) {
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
email = strings.TrimSpace(strings.ToLower(email))
if email == "" || len(email) > 255 {
return nil, nil, ErrEmailVerifyRequired
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, ErrEmailVerifyRequired
}
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
if strings.TrimSpace(password) == "" {
return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
}
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
return nil, nil, ErrEmailExists
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = normalizeOAuthSignupSource(signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
return nil, nil, ErrServiceUnavailable
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
// only after the pending OAuth flow has fully reached its last reversible step.
func (s *AuthService) FinalizeOAuthEmailAccount(

View File

@ -229,6 +229,67 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
}
func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) {
tests := []struct {
name string
email string
signupSource string
want string
}{
{
name: "github",
email: "github@example.com",
signupSource: " GitHub ",
want: "github",
},
{
name: "google",
email: "google@example.com",
signupSource: " Google ",
want: "google",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
tt.email,
"secret-123",
"246810",
"",
tt.signupSource,
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, tt.want, userRepo.created[0].SignupSource)
})
}
}
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
@ -256,7 +317,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing
"secret-123",
"246810",
"",
"github",
"unknown-provider",
)
require.NoError(t, err)

View File

@ -775,6 +775,10 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
return defaults.OIDC, true
case "wechat":
return defaults.WeChat, true
case "github":
return defaults.GitHub, true
case "google":
return defaults.Google, true
default:
return ProviderDefaultGrantSettings{}, false
}

View File

@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
return ok, nil
}
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}

View File

@ -0,0 +1,64 @@
package service
import "strings"
const featureKeyCodexImageGenerationBridge = "codex_image_generation_bridge"
func boolOverridePtr(v bool) *bool {
return &v
}
func boolOverrideFromMap(values map[string]any, keys ...string) *bool {
if values == nil {
return nil
}
for _, key := range keys {
if v, ok := values[key].(bool); ok {
return boolOverridePtr(v)
}
}
return nil
}
func platformBoolOverride(values map[string]any, key string, platform string) *bool {
if values == nil {
return nil
}
if v, ok := values[key].(bool); ok {
return boolOverridePtr(v)
}
raw, ok := values[key].(map[string]any)
if !ok {
return nil
}
platform = strings.TrimSpace(platform)
if platform == "" {
return nil
}
if v, ok := raw[platform].(bool); ok {
return boolOverridePtr(v)
}
return nil
}
// CodexImageGenerationBridgeOverride returns the channel-level override for Codex
// image_generation bridge injection. Nil means follow the global/account policy.
func (c *Channel) CodexImageGenerationBridgeOverride(platform string) *bool {
if c == nil {
return nil
}
return platformBoolOverride(c.FeaturesConfig, featureKeyCodexImageGenerationBridge, platform)
}
// CodexImageGenerationBridgeOverride returns the account-level override for Codex
// image_generation bridge injection. Nil means follow the channel/global policy.
func (a *Account) CodexImageGenerationBridgeOverride() *bool {
if a == nil || a.Platform != PlatformOpenAI || a.Extra == nil {
return nil
}
if override := boolOverrideFromMap(a.Extra, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled"); override != nil {
return override
}
openaiConfig, _ := a.Extra[PlatformOpenAI].(map[string]any)
return boolOverrideFromMap(openaiConfig, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled")
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,117 @@
package service
import (
"fmt"
"html"
"strings"
"time"
)
func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
if log == nil {
return ""
}
userName := strings.TrimSpace(log.UserEmail)
if userName == "" && log.UserID != nil {
userName = fmt.Sprintf("UID %d", *log.UserID)
}
threshold := cfg.BanThreshold
if threshold <= 0 {
threshold = defaultContentModerationBanThreshold
}
statusBlock := ""
if log.AutoBanned {
statusBlock = `<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>`
}
return fmt.Sprintf(`<!doctype html>
<html>
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 风控提醒</div>
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户触发内容审计规则</h1>
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>您的 API 请求在内容审计中触发平台风控策略详情如下</p>
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">触发详情</h2>
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 阈值 %d</td></tr>
</table>
</div>
%s
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送请勿回复</p>
</div>
</div>
</body>
</html>`,
html.EscapeString(userName),
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
log.HighestScore,
log.ViolationCount,
threshold,
statusBlock,
html.EscapeString(siteName),
)
}
func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
if log == nil {
return ""
}
userName := strings.TrimSpace(log.UserEmail)
if userName == "" && log.UserID != nil {
userName = fmt.Sprintf("UID %d", *log.UserID)
}
threshold := cfg.BanThreshold
if threshold <= 0 {
threshold = defaultContentModerationBanThreshold
}
return fmt.Sprintf(`<!doctype html>
<html>
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 账户封禁</div>
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户已被自动禁用</h1>
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>您的账户在计数周期内多次触发平台风控策略系统已自动禁用该账户详情如下</p>
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">封禁详情</h2>
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">封禁时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 阈值 %d</td></tr>
</table>
</div>
<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态所有 API 请求将被拒绝</div>
<p style="font-size:15px;line-height:1.8;color:#666;margin-top:24px;">如需申诉或恢复账号请联系平台管理员处理</p>
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送请勿回复</p>
</div>
</div>
</body>
</html>`,
html.EscapeString(userName),
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
log.HighestScore,
log.ViolationCount,
threshold,
html.EscapeString(siteName),
)
}
func defaultContentModerationString(value string, fallback string) string {
if strings.TrimSpace(value) == "" {
return fallback
}
return strings.TrimSpace(value)
}

View File

@ -0,0 +1,320 @@
package service
import (
"crypto/rand"
"fmt"
"math/big"
"strings"
"github.com/tidwall/gjson"
)
func ExtractContentModerationText(protocol string, body []byte) string {
return ExtractContentModerationInput(protocol, body).Text
}
func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput {
if len(body) == 0 || !gjson.ValidBytes(body) {
return ContentModerationInput{}
}
var parts []string
var images []string
switch protocol {
case ContentModerationProtocolAnthropicMessages:
collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images)
case ContentModerationProtocolOpenAIChat:
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
case ContentModerationProtocolOpenAIResponses:
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
case ContentModerationProtocolGemini:
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
case ContentModerationProtocolOpenAIImages:
addModerationText(&parts, gjson.GetBytes(body, "prompt").String())
collectContentValue(gjson.GetBytes(body, "images"), &parts, &images)
default:
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
}
out := ContentModerationInput{
Text: normalizeContentModerationText(strings.Join(parts, "\n")),
Images: normalizeModerationImages(images),
}
out.Normalize()
return out
}
func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
var candidate []string
var candidateImages []string
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
var candidate []string
var candidateImages []string
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
switch {
case !value.Exists():
return
case value.Type == gjson.String:
if !isAnthropicSystemReminderText(value.String()) {
addModerationText(parts, value.String())
}
case value.IsArray():
value.ForEach(func(_, item gjson.Result) bool {
collectAnthropicUserContentValue(item, parts, images)
return true
})
case value.IsObject():
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
switch typ {
case "", "text", "input_text", "message":
if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) {
addModerationText(parts, value.Get("text").String())
}
if value.Get("content").Exists() {
collectAnthropicUserContentValue(value.Get("content"), parts, images)
}
case "image_url", "input_image", "image":
collectContentValue(value, parts, images)
}
}
}
func isAnthropicSystemReminderText(text string) bool {
return strings.HasPrefix(strings.TrimSpace(text), "<system-reminder>")
}
func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) {
switch {
case !input.Exists():
return
case input.Type == gjson.String:
addModerationText(parts, input.String())
case input.IsArray():
var last gjson.Result
input.ForEach(func(_, item gjson.Result) bool {
if isResponsesUserTextItem(item) {
last = item
}
return true
})
if last.Exists() {
collectContentValue(last.Get("content"), parts, images)
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
collectContentValue(last, parts, images)
}
}
case input.IsObject():
if isResponsesUserTextItem(input) {
collectContentValue(input.Get("content"), parts, images)
if input.Get("type").String() == "input_text" || input.Get("text").Exists() {
collectContentValue(input, parts, images)
}
}
}
}
func isResponsesUserTextItem(item gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(item.Get("role").String()))
if role == "user" {
return responseItemHasModerationText(item)
}
if role != "" {
return false
}
return responseItemHasModerationText(item)
}
func responseItemHasModerationText(item gjson.Result) bool {
var parts []string
var images []string
collectContentValue(item.Get("content"), &parts, &images)
if item.Get("type").String() == "input_text" || item.Get("text").Exists() {
collectContentValue(item, &parts, &images)
}
return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0
}
func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) {
if !contents.IsArray() {
return
}
var lastParts []string
var lastImages []string
contents.ForEach(func(_, content gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
if role == "" || role == "user" {
var candidate []string
var candidateImages []string
if arr := content.Get("parts"); arr.IsArray() {
arr.ForEach(func(_, part gjson.Result) bool {
addModerationText(&candidate, part.Get("text").String())
addGeminiModerationImage(&candidateImages, part)
return true
})
}
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
}
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {
switch {
case !value.Exists():
return
case value.Type == gjson.String:
addModerationText(parts, value.String())
case value.IsArray():
value.ForEach(func(_, item gjson.Result) bool {
collectContentValue(item, parts, images)
return true
})
case value.IsObject():
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
addModerationImage(images, value.Get("image_url.url").String())
addModerationImage(images, value.Get("image_url").String())
addModerationImage(images, value.Get("url").String())
addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String())
addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String())
addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String())
addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String())
addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String())
addModerationImage(images, value.Get("source.data").String())
addModerationImage(images, value.Get("data").String())
addModerationImage(images, value.Get("base64").String())
switch typ {
case "", "text", "input_text", "message":
if value.Get("text").Exists() {
addModerationText(parts, value.Get("text").String())
}
if value.Get("content").Exists() {
collectContentValue(value.Get("content"), parts, images)
}
case "image_url", "input_image", "image":
}
}
}
func addGeminiModerationImage(images *[]string, part gjson.Result) {
if inlineData := part.Get("inline_data"); inlineData.IsObject() {
mimeType := strings.TrimSpace(inlineData.Get("mime_type").String())
data := strings.TrimSpace(inlineData.Get("data").String())
if mimeType != "" && data != "" {
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
}
if inlineData := part.Get("inlineData"); inlineData.IsObject() {
mimeType := strings.TrimSpace(inlineData.Get("mimeType").String())
data := strings.TrimSpace(inlineData.Get("data").String())
if mimeType != "" && data != "" {
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
}
addModerationImage(images, part.Get("file_data.file_uri").String())
addModerationImage(images, part.Get("fileData.fileUri").String())
}
func addModerationImageData(images *[]string, mimeType string, data string) {
mimeType = strings.TrimSpace(mimeType)
data = strings.TrimSpace(data)
if mimeType == "" || data == "" {
return
}
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
}
func addModerationImage(images *[]string, image string) {
image = strings.TrimSpace(image)
if image == "" {
return
}
if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") {
*images = append(*images, image)
}
}
func normalizeModerationImages(images []string) []string {
out := make([]string, 0, len(images))
seen := make(map[string]struct{}, len(images))
for _, image := range images {
image = strings.TrimSpace(image)
if image == "" {
continue
}
if _, ok := seen[image]; ok {
continue
}
seen[image] = struct{}{}
out = append(out, image)
}
return out
}
func limitContentModerationImages(images []string) []string {
if len(images) <= maxContentModerationInputImages {
return images
}
idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(images))))
if err != nil {
return images[:maxContentModerationInputImages]
}
return []string{images[int(idx.Int64())]}
}
func addModerationText(parts *[]string, text string) {
text = strings.TrimSpace(text)
if text == "" {
return
}
if strings.Contains(text, "<system-reminder>") {
return
}
*parts = append(*parts, text)
}
func normalizeContentModerationText(text string) string {
return strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
}

View File

@ -0,0 +1,37 @@
package service
import (
"regexp"
"strings"
)
var contentModerationSecretPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)\bhttps?://[^\s"'<>,。;、]+`),
regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`),
regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`),
regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`),
regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`),
regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`),
regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`),
regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`),
regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`),
}
func redactContentModerationSecrets(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
out := text
for idx, pattern := range contentModerationSecretPatterns {
switch idx {
case 1:
out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`)
case 2:
out = pattern.ReplaceAllString(out, `${1}[已脱敏]`)
default:
out = pattern.ReplaceAllString(out, `[已脱敏]`)
}
}
return out
}

File diff suppressed because it is too large Load Diff

View File

@ -108,6 +108,12 @@ const (
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期小时0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限0=无上限)
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置JSON
SettingKeyLoginAgreementEnabled = "login_agreement_enabled" // 登录前是否要求同意条款
SettingKeyLoginAgreementMode = "login_agreement_mode" // 条款确认展示模式modal / checkbox
SettingKeyLoginAgreementUpdatedAt = "login_agreement_updated_at" // 条款更新日期(展示用)
SettingKeyLoginAgreementDocuments = "login_agreement_documents" // 条款文档列表JSONMarkdown 内容)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@ -174,6 +180,18 @@ const (
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
// GitHub / Google 邮箱快捷登录设置
SettingKeyGitHubOAuthEnabled = "github_oauth_enabled"
SettingKeyGitHubOAuthClientID = "github_oauth_client_id"
SettingKeyGitHubOAuthClientSecret = "github_oauth_client_secret"
SettingKeyGitHubOAuthRedirectURL = "github_oauth_redirect_url"
SettingKeyGitHubOAuthFrontendRedirectURL = "github_oauth_frontend_redirect_url"
SettingKeyGoogleOAuthEnabled = "google_oauth_enabled"
SettingKeyGoogleOAuthClientID = "google_oauth_client_id"
SettingKeyGoogleOAuthClientSecret = "google_oauth_client_secret"
SettingKeyGoogleOAuthRedirectURL = "google_oauth_redirect_url"
SettingKeyGoogleOAuthFrontendRedirectURL = "google_oauth_frontend_redirect_url"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
@ -217,6 +235,16 @@ const (
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key

View File

@ -124,6 +124,24 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
require.Contains(t, got, "fast-mode-2026-02-01")
}
func TestFullClaudeCodeMimicryBetas_DoesNotDefaultRedactThinking(t *testing.T) {
required := claude.FullClaudeCodeMimicryBetas()
require.NotContains(t, required, claude.BetaRedactThinking)
require.Contains(t, required, claude.BetaClaudeCode)
require.Contains(t, required, claude.BetaOAuth)
require.Contains(t, required, claude.BetaInterleavedThinking)
}
func TestMergeAnthropicBetaDropping_PreservesIncomingRedactThinking(t *testing.T) {
required := claude.FullClaudeCodeMimicryBetas()
incoming := claude.BetaRedactThinking
got := mergeAnthropicBetaDropping(required, incoming, droppedBetaSet())
require.Contains(t, got, claude.BetaRedactThinking)
}
func TestDroppedBetaSet(t *testing.T) {
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
base := droppedBetaSet()

View File

@ -5445,6 +5445,12 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
flusher.Flush()
}
if !sawTerminalEvent {
if clientDisconnected && streamInterval > 0 {
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) >= streamInterval {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
}
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil

View File

@ -440,6 +440,21 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
func (s *OpenAIGatewayService) isCodexImageGenerationBridgeEnabled(ctx context.Context, account *Account, apiKey *APIKey) bool {
if override := account.CodexImageGenerationBridgeOverride(); override != nil {
return *override
}
if s != nil && s.channelService != nil && apiKey != nil && apiKey.GroupID != nil {
ch, err := s.channelService.GetChannelForGroup(ctx, *apiKey.GroupID)
if err != nil {
slog.Warn("failed to resolve codex image generation bridge channel override", "group_id", *apiKey.GroupID, "error", err)
} else if override := ch.CodexImageGenerationBridgeOverride(PlatformOpenAI); override != nil {
return *override
}
}
return s != nil && s.cfg != nil && s.cfg.Gateway.CodexImageGenerationBridgeEnabled
}
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
@ -2059,6 +2074,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if apiKey != nil {
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
}
codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey)
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
c.JSON(http.StatusForbidden, gin.H{
@ -2128,7 +2144,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("instructions", "You are a helpful coding assistant.")
}
if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
if codexImageGenerationBridgeEnabled && ensureOpenAIResponsesImageGenerationTool(reqBody) {
bodyModified = true
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
@ -2139,7 +2155,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
}
if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
if codexImageGenerationBridgeEnabled && applyCodexImageGenerationBridgeInstructions(reqBody) {
bodyModified = true
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")

View File

@ -83,12 +83,14 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
gin.SetMode(gin.TestMode)
tests := []struct {
name string
allowImages bool
wantInjected bool
name string
allowImages bool
bridgeEnabled bool
wantInjected bool
}{
{name: "disabled group skips injection", allowImages: false, wantInjected: false},
{name: "enabled group injects image tool", allowImages: true, wantInjected: true},
{name: "disabled group skips injection", allowImages: false, bridgeEnabled: true, wantInjected: false},
{name: "enabled group skips injection by default", allowImages: true, bridgeEnabled: false, wantInjected: false},
{name: "enabled group injects image tool when bridge enabled", allowImages: true, bridgeEnabled: true, wantInjected: true},
}
for _, tt := range tests {
@ -101,6 +103,7 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
},
}
svc := newOpenAIImageGenerationControlTestService(upstream)
svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.bridgeEnabled
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
account := newOpenAIImageGenerationControlTestAccount()
@ -117,6 +120,154 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
}
}
func TestOpenAIGatewayServiceForward_ExplicitImageToolWorksWithBridgeDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_explicit_image","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}`)),
},
}
svc := newOpenAIImageGenerationControlTestService(upstream)
c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0")
account := newOpenAIImageGenerationControlTestAccount()
body := []byte(`{"model":"gpt-5.4","input":"draw","stream":false,"tools":[{"type":"image_generation","format":"jpeg"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists())
require.Equal(t, "jpeg", gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").output_format`).String())
require.False(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").format`).Exists())
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
require.NotContains(t, instructions, "image_generation")
}
func TestOpenAIGatewayServiceForward_ChannelBridgeOverrideEnablesCodexInjection(t *testing.T) {
gin.SetMode(gin.TestMode)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_channel_bridge","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
svc := newOpenAIImageGenerationControlTestService(upstream)
groupID := int64(4242)
svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, &Channel{
ID: 9001,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
},
})
c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0")
account := newOpenAIImageGenerationControlTestAccount()
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists())
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
require.Contains(t, instructions, "image_generation")
}
func TestOpenAIGatewayService_CodexImageGenerationBridgeOverridePrecedence(t *testing.T) {
groupID := int64(4242)
tests := []struct {
name string
global bool
channel *Channel
account *Account
want bool
}{
{
name: "global default enables bridge",
global: true,
account: &Account{
Platform: PlatformOpenAI,
},
want: true,
},
{
name: "channel true overrides disabled global",
global: false,
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
}},
account: &Account{Platform: PlatformOpenAI},
want: true,
},
{
name: "channel false overrides enabled global",
global: true,
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false},
}},
account: &Account{Platform: PlatformOpenAI},
want: false,
},
{
name: "account false overrides channel and global true",
global: true,
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
}},
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{featureKeyCodexImageGenerationBridge: false},
},
want: false,
},
{
name: "nested account true overrides channel false",
global: false,
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false},
}},
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
PlatformOpenAI: map[string]any{"codex_image_generation_bridge_enabled": true},
},
},
want: true,
},
{
name: "non openai account extra is ignored",
global: false,
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{featureKeyCodexImageGenerationBridge: true},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.global
if tt.channel != nil {
svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, tt.channel)
}
apiKey := &APIKey{GroupID: &groupID}
got := svc.isCodexImageGenerationBridgeEnabled(context.Background(), tt.account, apiKey)
require.Equal(t, tt.want, got)
})
}
}
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
@ -180,6 +331,18 @@ func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder)
}
}
func newOpenAIImageGenerationControlChannelService(groupID int64, ch *Channel) *ChannelService {
svc := &ChannelService{}
cache := newEmptyChannelCache()
if ch != nil {
cache.channelByGroupID[groupID] = ch
cache.byID[ch.ID] = ch
}
cache.loadedAt = time.Now()
svc.cache.Store(cache)
return svc
}
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)

View File

@ -90,6 +90,69 @@ type OpenAIImagesRequest struct {
bodyHash string
}
func (r *OpenAIImagesRequest) ModerationBody() []byte {
if r == nil {
return nil
}
payload := map[string]any{}
if prompt := strings.TrimSpace(r.Prompt); prompt != "" {
payload["prompt"] = prompt
}
images := r.moderationImages()
if len(images) > 0 {
payload["images"] = images
}
if len(payload) == 0 {
return nil
}
body, err := json.Marshal(payload)
if err != nil {
return nil
}
return body
}
func (r *OpenAIImagesRequest) moderationImages() []map[string]string {
if r == nil {
return nil
}
images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1)
for _, imageURL := range r.InputImageURLs {
imageURL = strings.TrimSpace(imageURL)
if imageURL != "" {
images = append(images, map[string]string{"image_url": imageURL})
}
}
for _, upload := range r.Uploads {
if dataURL := upload.ModerationDataURL(); dataURL != "" {
images = append(images, map[string]string{"image_url": dataURL})
}
}
if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" {
images = append(images, map[string]string{"image_url": maskURL})
}
if r.MaskUpload != nil {
if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" {
images = append(images, map[string]string{"image_url": dataURL})
}
}
return images
}
func (u OpenAIImagesUpload) ModerationDataURL() string {
if len(u.Data) == 0 {
return ""
}
contentType := strings.TrimSpace(u.ContentType)
if contentType == "" {
contentType = http.DetectContentType(u.Data)
}
if !strings.HasPrefix(strings.ToLower(contentType), "image/") {
return ""
}
return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data))
}
func (r *OpenAIImagesRequest) IsEdits() bool {
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
}

View File

@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,
Prompt: "replace background",
InputImageURLs: []string{"https://example.com/source.png"},
MaskImageURL: "https://example.com/mask.png",
}
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images)
}
func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,
Prompt: "replace background",
Uploads: []OpenAIImagesUpload{{
FieldName: "image",
FileName: "source.png",
ContentType: "image/png",
Data: []byte("fake-image-bytes"),
}},
MaskUpload: &OpenAIImagesUpload{
FieldName: "mask",
FileName: "mask.png",
ContentType: "image/png",
Data: []byte("fake-mask-bytes"),
},
}
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
require.Equal(t, "replace background", input.Text)
require.Equal(t, []string{
"data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==",
"data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz",
}, input.Images)
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
require.Equal(t, "replace background", log.InputExcerpt)
require.NotContains(t, log.InputExcerpt, "ZmFrZS")
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -52,3 +52,47 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess
require.Equal(t, "client-id-1", info.ClientID)
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
}
func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) {
refresher := NewOpenAITokenRefresher(nil, nil)
expiresAt := time.Now().Add(time.Minute).UTC().Format(time.RFC3339)
withoutRT := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": expiresAt,
},
}
require.False(t, refresher.NeedsRefresh(withoutRT, 5*time.Minute))
withRT := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "access-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}
require.True(t, refresher.NeedsRefresh(withRT, 5*time.Minute))
}
func TestOpenAITokenProvider_NoRefreshTokenExpiredAccessTokenReturnsError(t *testing.T) {
provider := NewOpenAITokenProvider(nil, nil, nil)
expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "expired-access-token",
"expires_at": expiresAt,
},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Empty(t, token)
require.Contains(t, err.Error(), "refresh_token is missing")
}

View File

@ -152,6 +152,12 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
if expiresAt != nil && !time.Now().Before(*expiresAt) {
return "", errors.New("openai access_token expired and refresh_token is missing")
}
needsRefresh = false
}
refreshFailed := false
if needsRefresh && p.refreshAPI != nil && p.executor != nil {

View File

@ -424,8 +424,9 @@ func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
"access_token": "fallback-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}
@ -650,8 +651,9 @@ func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
"access_token": "fallback-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}
@ -819,8 +821,9 @@ func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
"access_token": "fallback-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}
@ -848,8 +851,9 @@ func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
"access_token": "fallback-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}
@ -875,8 +879,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T)
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
"access_token": "fallback-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}
cacheKey := OpenAITokenCacheKey(account)
@ -911,8 +916,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
"access_token": "fallback-token",
"refresh_token": "refresh-token",
"expires_at": expiresAt,
},
}

View File

@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct {
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
InitialRequestModel string
BeforeTurn func(turn int) error
BeforeRequest func(turn int, payload []byte, originalModel string) error
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
}
@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return true
}
for {
if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil {
if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil {
return err
}
}
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
if err := hooks.BeforeTurn(turn); err != nil {
return err

View File

@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
if msgType != coderws.MessageText {
return payload, nil, nil
}
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil {
turnNo := int(completedTurns.Load()) + 1
if turnNo < 2 {
turnNo = 2
}
requestModel := usageMeta.requestModelForFrame(payload)
if requestModel == "" {
requestModel = capturedSessionModel
}
if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil {
return payload, nil, err
}
}
// 在评估策略前先刷新 capturedSessionModel客户端可能通过
// session.update 修改 session-level modelRealtime /
// Responses WS 协议允许),如果不刷新就会出现

View File

@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
case redeemActionRedeem:
// Code exists but unused — skip creation, proceed to redeem
}
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {

View File

@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
nil,
client,
nil,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{

View File

@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@ -28,6 +29,15 @@ const (
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
type ctxKeySkipRedeemAffiliate struct{}
// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level
// affiliate rebate. Used by payment fulfillment which handles rebate separately
// via applyAffiliateRebateForOrder (with audit-log deduplication).
func ContextSkipRedeemAffiliate(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true)
}
// RedeemCache defines cache operations for redeem service
type RedeemCache interface {
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
@ -80,6 +90,7 @@ type RedeemService struct {
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
affiliateService *AffiliateService
}
// NewRedeemService 创建兑换码服务实例
@ -91,6 +102,7 @@ func NewRedeemService(
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
affiliateService *AffiliateService,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
@ -100,6 +112,7 @@ func NewRedeemService(
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
affiliateService: affiliateService,
}
}
@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 事务提交成功后失效缓存
s.invalidateRedeemCaches(ctx, userID, redeemCode)
// 余额类正数兑换码触发邀请返利best-effort失败不影响兑换结果
if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 {
s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value)
}
// 重新获取更新后的兑换码
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
if err != nil {
@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64
}
}
func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) {
if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil {
return
}
if s.affiliateService == nil {
return
}
if !s.affiliateService.IsEnabled(ctx) {
return
}
rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount)
if err != nil {
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err)
return
}
if rebate > 0 {
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID)
}
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
@ -129,6 +130,8 @@ type AuthSourceDefaultSettings struct {
LinuxDo ProviderDefaultGrantSettings
OIDC ProviderDefaultGrantSettings
WeChat ProviderDefaultGrantSettings
GitHub ProviderDefaultGrantSettings
Google ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup bool
}
@ -169,6 +172,20 @@ var (
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
}
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultGitHubBalance,
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
}
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultGoogleBalance,
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
}
)
const (
@ -177,8 +194,151 @@ const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
defaultWeChatConnectFrontend = "/auth/wechat/callback"
defaultGitHubOAuthAuthorize = "https://github.com/login/oauth/authorize"
defaultGitHubOAuthToken = "https://github.com/login/oauth/access_token"
defaultGitHubOAuthUserInfo = "https://api.github.com/user"
defaultGitHubOAuthEmails = "https://api.github.com/user/emails"
defaultGitHubOAuthScopes = "read:user user:email"
defaultGitHubOAuthFrontend = "/auth/oauth/callback"
defaultGoogleOAuthAuthorize = "https://accounts.google.com/o/oauth2/v2/auth"
defaultGoogleOAuthToken = "https://oauth2.googleapis.com/token"
defaultGoogleOAuthUserInfo = "https://openidconnect.googleapis.com/v1/userinfo"
defaultGoogleOAuthScopes = "openid email profile"
defaultGoogleOAuthFrontend = "/auth/oauth/callback"
defaultLoginAgreementMode = "modal"
defaultLoginAgreementDate = "2026-03-31"
)
func normalizeLoginAgreementMode(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "checkbox":
return "checkbox"
default:
return defaultLoginAgreementMode
}
}
func defaultLoginAgreementDocuments() []LoginAgreementDocument {
return []LoginAgreementDocument{
{
ID: "terms",
Title: "服务条款",
ContentMD: "",
},
{
ID: "usage-policy",
Title: "使用政策",
ContentMD: "",
},
{
ID: "supported-regions",
Title: "支持的国家和地区",
ContentMD: "",
},
{
ID: "service-specific-terms",
Title: "服务特定条款",
ContentMD: "",
},
}
}
func normalizeLoginAgreementDocumentID(raw string) string {
raw = strings.ToLower(strings.TrimSpace(raw))
var b strings.Builder
lastSeparator := false
for _, r := range raw {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') {
_, _ = b.WriteRune(r)
lastSeparator = false
continue
}
if r == '-' || r == '_' || r == ' ' || r == '.' || r == '/' {
if !lastSeparator && b.Len() > 0 {
if r == '_' {
_, _ = b.WriteRune('_')
} else {
_, _ = b.WriteRune('-')
}
lastSeparator = true
}
}
}
return strings.Trim(b.String(), "-_")
}
func normalizeLoginAgreementDocuments(docs []LoginAgreementDocument) []LoginAgreementDocument {
normalized := make([]LoginAgreementDocument, 0, len(docs))
seen := make(map[string]int, len(docs))
for i, doc := range docs {
title := strings.TrimSpace(doc.Title)
content := strings.TrimSpace(doc.ContentMD)
if title == "" && content == "" {
continue
}
id := normalizeLoginAgreementDocumentID(doc.ID)
if id == "" {
sum := sha256.Sum256([]byte(fmt.Sprintf("%d:%s:%s", i, title, content)))
id = hex.EncodeToString(sum[:])[:12]
}
baseID := id
for suffix := 2; seen[id] > 0; suffix++ {
id = fmt.Sprintf("%s-%d", baseID, suffix)
}
seen[id]++
normalized = append(normalized, LoginAgreementDocument{
ID: id,
Title: title,
ContentMD: content,
})
}
return normalized
}
func parseLoginAgreementDocuments(raw string) []LoginAgreementDocument {
raw = strings.TrimSpace(raw)
if raw == "" {
return defaultLoginAgreementDocuments()
}
var docs []LoginAgreementDocument
if err := json.Unmarshal([]byte(raw), &docs); err != nil {
return defaultLoginAgreementDocuments()
}
docs = normalizeLoginAgreementDocuments(docs)
if len(docs) == 0 {
return defaultLoginAgreementDocuments()
}
return docs
}
func marshalLoginAgreementDocuments(docs []LoginAgreementDocument) (string, error) {
normalized := normalizeLoginAgreementDocuments(docs)
if len(normalized) == 0 {
normalized = defaultLoginAgreementDocuments()
}
b, err := json.Marshal(normalized)
if err != nil {
return "", fmt.Errorf("marshal login agreement documents: %w", err)
}
return string(b), nil
}
func buildLoginAgreementRevision(updatedAt string, docs []LoginAgreementDocument) string {
normalized := normalizeLoginAgreementDocuments(docs)
payload, err := json.Marshal(struct {
UpdatedAt string `json:"updated_at"`
Documents []LoginAgreementDocument `json:"documents"`
}{
UpdatedAt: strings.TrimSpace(updatedAt),
Documents: normalized,
})
if err != nil {
payload = []byte(strings.TrimSpace(updatedAt))
}
sum := sha256.Sum256(payload)
return hex.EncodeToString(sum[:])[:16]
}
func normalizeWeChatConnectModeSetting(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "mp":
@ -411,6 +571,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyPasswordResetEnabled,
SettingKeyInvitationCodeEnabled,
SettingKeyTotpEnabled,
SettingKeyLoginAgreementEnabled,
SettingKeyLoginAgreementMode,
SettingKeyLoginAgreementUpdatedAt,
SettingKeyLoginAgreementDocuments,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
@ -448,6 +612,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingKeyGitHubOAuthEnabled,
SettingKeyGitHubOAuthClientID,
SettingKeyGitHubOAuthClientSecret,
SettingKeyGoogleOAuthEnabled,
SettingKeyGoogleOAuthClientID,
SettingKeyGoogleOAuthClientSecret,
SettingKeyBalanceLowNotifyEnabled,
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
@ -456,6 +626,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyChannelMonitorDefaultIntervalSeconds,
SettingKeyAvailableChannelsEnabled,
SettingKeyAffiliateEnabled,
SettingKeyRiskControlEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@ -482,6 +653,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github")
googleEnabled := s.emailOAuthPublicEnabled(settings, "google")
weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
// Password reset requires email verification to be enabled
@ -494,6 +667,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
settings[SettingKeyTableDefaultPageSize],
settings[SettingKeyTablePageSizeOptions],
)
loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
if loginAgreementUpdatedAt == "" {
loginAgreementUpdatedAt = defaultLoginAgreementDate
}
var balanceLowNotifyThreshold float64
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
@ -509,6 +687,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PasswordResetEnabled: passwordResetEnabled,
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true" && len(loginAgreementDocuments) > 0,
LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementRevision: buildLoginAgreementRevision(loginAgreementUpdatedAt, loginAgreementDocuments),
LoginAgreementDocuments: loginAgreementDocuments,
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
@ -534,6 +717,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
GitHubOAuthEnabled: gitHubEnabled,
GoogleOAuthEnabled: googleEnabled,
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
@ -545,6 +730,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true",
}, nil
}
@ -647,43 +834,50 @@ func (s *SettingService) SetVersion(version string) {
// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
// drift automatically (see setting_service_injection_test.go).
type PublicSettingsInjectionPayload struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"`
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
LoginAgreementMode string `json:"login_agreement_mode"`
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
LoginAgreementRevision string `json:"login_agreement_revision"`
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
// Feature flags — MUST match the opt-in/opt-out registry in
// frontend/src/utils/featureFlags.ts. Missing a field here is the bug
@ -692,6 +886,7 @@ type PublicSettingsInjectionPayload struct {
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
AffiliateEnabled bool `json:"affiliate_enabled"`
RiskControlEnabled bool `json:"risk_control_enabled"`
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
@ -710,6 +905,11 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
LoginAgreementEnabled: settings.LoginAgreementEnabled,
LoginAgreementMode: settings.LoginAgreementMode,
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
LoginAgreementRevision: settings.LoginAgreementRevision,
LoginAgreementDocuments: settings.LoginAgreementDocuments,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
@ -733,6 +933,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version,
@ -745,6 +947,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
RiskControlEnabled: settings.RiskControlEnabled,
}, nil
}
@ -806,6 +1009,98 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string
return openReady || mpReady, openReady, mpReady, mobileReady
}
func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github":
cfg := config.EmailOAuthProviderConfig{
AuthorizeURL: defaultGitHubOAuthAuthorize,
TokenURL: defaultGitHubOAuthToken,
UserInfoURL: defaultGitHubOAuthUserInfo,
EmailsURL: defaultGitHubOAuthEmails,
Scopes: defaultGitHubOAuthScopes,
FrontendRedirectURL: defaultGitHubOAuthFrontend,
}
if s != nil && s.cfg != nil {
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GitHubOAuth)
}
return cfg
case "google":
cfg := config.EmailOAuthProviderConfig{
AuthorizeURL: defaultGoogleOAuthAuthorize,
TokenURL: defaultGoogleOAuthToken,
UserInfoURL: defaultGoogleOAuthUserInfo,
Scopes: defaultGoogleOAuthScopes,
FrontendRedirectURL: defaultGoogleOAuthFrontend,
}
if s != nil && s.cfg != nil {
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GoogleOAuth)
}
return cfg
default:
return config.EmailOAuthProviderConfig{}
}
}
func mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) config.EmailOAuthProviderConfig {
base.Enabled = override.Enabled
if strings.TrimSpace(override.ClientID) != "" {
base.ClientID = strings.TrimSpace(override.ClientID)
}
if strings.TrimSpace(override.ClientSecret) != "" {
base.ClientSecret = strings.TrimSpace(override.ClientSecret)
}
if strings.TrimSpace(override.AuthorizeURL) != "" {
base.AuthorizeURL = strings.TrimSpace(override.AuthorizeURL)
}
if strings.TrimSpace(override.TokenURL) != "" {
base.TokenURL = strings.TrimSpace(override.TokenURL)
}
if strings.TrimSpace(override.UserInfoURL) != "" {
base.UserInfoURL = strings.TrimSpace(override.UserInfoURL)
}
if strings.TrimSpace(override.EmailsURL) != "" {
base.EmailsURL = strings.TrimSpace(override.EmailsURL)
}
if strings.TrimSpace(override.Scopes) != "" {
base.Scopes = strings.TrimSpace(override.Scopes)
}
if strings.TrimSpace(override.RedirectURL) != "" {
base.RedirectURL = strings.TrimSpace(override.RedirectURL)
}
if strings.TrimSpace(override.FrontendRedirectURL) != "" {
base.FrontendRedirectURL = strings.TrimSpace(override.FrontendRedirectURL)
}
return base
}
func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool {
cfg := s.effectiveEmailOAuthConfig(settings, provider)
return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != ""
}
func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig {
cfg := s.emailOAuthBaseConfig(provider)
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github":
if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok {
cfg.Enabled = raw == "true"
}
cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID)
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret)
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL)
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend)
case "google":
if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok {
cfg.Enabled = raw == "true"
}
cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID)
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret)
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL)
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend)
}
return cfg
}
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
// array string, returning only items with visibility != "admin".
func filterUserVisibleMenuItems(raw string) json.RawMessage {
@ -1052,6 +1347,16 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
if settings.WeChatConnectFrontendRedirectURL == "" {
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
}
settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL)
settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL)
if settings.GitHubOAuthFrontendRedirectURL == "" {
settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend
}
settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL)
settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL)
if settings.GoogleOAuthFrontendRedirectURL == "" {
settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend
}
updates := make(map[string]string)
@ -1068,6 +1373,19 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyFrontendURL] = settings.FrontendURL
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
settings.LoginAgreementMode = normalizeLoginAgreementMode(settings.LoginAgreementMode)
settings.LoginAgreementUpdatedAt = strings.TrimSpace(settings.LoginAgreementUpdatedAt)
if settings.LoginAgreementUpdatedAt == "" {
settings.LoginAgreementUpdatedAt = defaultLoginAgreementDate
}
loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(settings.LoginAgreementDocuments)
if err != nil {
return nil, err
}
updates[SettingKeyLoginAgreementEnabled] = strconv.FormatBool(settings.LoginAgreementEnabled)
updates[SettingKeyLoginAgreementMode] = settings.LoginAgreementMode
updates[SettingKeyLoginAgreementUpdatedAt] = settings.LoginAgreementUpdatedAt
updates[SettingKeyLoginAgreementDocuments] = loginAgreementDocumentsJSON
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
@ -1121,6 +1439,22 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
// GitHub / Google 邮箱快捷登录
updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled)
updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID)
updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL
updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL
if settings.GitHubOAuthClientSecret != "" {
updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret)
}
updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled)
updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID)
updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL
updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL
if settings.GoogleOAuthClientSecret != "" {
updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret)
}
// WeChat Connect OAuth 登录
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
@ -1232,6 +1566,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// Affiliate (邀请返利) feature switch
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
// 风控中心功能开关
updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled)
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@ -1273,17 +1610,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
settings.GitHub.Subscriptions,
settings.Google.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return nil, err
}
}
updates := make(map[string]string, 21)
updates := make(map[string]string, 31)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
return updates, nil
}
@ -1362,6 +1703,61 @@ func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context,
return nil
}
func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider != "github" && provider != "google" {
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_PROVIDER_NOT_FOUND", "oauth provider not found")
}
keys := []string{
SettingKeyGitHubOAuthEnabled,
SettingKeyGitHubOAuthClientID,
SettingKeyGitHubOAuthClientSecret,
SettingKeyGitHubOAuthRedirectURL,
SettingKeyGitHubOAuthFrontendRedirectURL,
SettingKeyGoogleOAuthEnabled,
SettingKeyGoogleOAuthClientID,
SettingKeyGoogleOAuthClientSecret,
SettingKeyGoogleOAuthRedirectURL,
SettingKeyGoogleOAuthFrontendRedirectURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return config.EmailOAuthProviderConfig{}, fmt.Errorf("get email oauth settings: %w", err)
}
cfg := s.effectiveEmailOAuthConfig(settings, provider)
if !cfg.Enabled {
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
if strings.TrimSpace(cfg.ClientID) == "" {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
}
if strings.TrimSpace(cfg.ClientSecret) == "" {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
for label, rawURL := range map[string]string{
"authorize": cfg.AuthorizeURL,
"token": cfg.TokenURL,
"userinfo": cfg.UserInfoURL,
"redirect": cfg.RedirectURL,
} {
if strings.TrimSpace(rawURL) == "" {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(rawURL); err != nil {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url invalid")
}
}
if strings.TrimSpace(cfg.EmailsURL) != "" {
if err := config.ValidateAbsoluteHTTPURL(cfg.EmailsURL); err != nil {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth emails url invalid")
}
}
if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
}
return cfg, nil
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
@ -1534,6 +1930,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
return value == "true"
}
// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting.
func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems)
if err != nil {
return "[]"
}
return value
}
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
@ -1711,6 +2116,16 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
SettingKeyAuthSourceDefaultWeChatSubscriptions,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
SettingKeyAuthSourceDefaultGitHubBalance,
SettingKeyAuthSourceDefaultGitHubConcurrency,
SettingKeyAuthSourceDefaultGitHubSubscriptions,
SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
SettingKeyAuthSourceDefaultGoogleBalance,
SettingKeyAuthSourceDefaultGoogleConcurrency,
SettingKeyAuthSourceDefaultGoogleSubscriptions,
SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
SettingKeyForceEmailOnThirdPartySignup,
}
@ -1724,6 +2139,8 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys),
Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys),
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
}, nil
}
@ -1793,6 +2210,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
}
}
loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(defaultLoginAgreementDocuments())
if err != nil {
return err
}
// 初始化默认设置
defaults := map[string]string{
@ -1800,6 +2221,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyEmailVerifyEnabled: "false",
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeyLoginAgreementEnabled: "false",
SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
@ -1824,6 +2249,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyWeChatConnectScopes: "snsapi_login",
SettingKeyWeChatConnectRedirectURL: "",
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
SettingKeyGitHubOAuthEnabled: "false",
SettingKeyGitHubOAuthClientID: "",
SettingKeyGitHubOAuthClientSecret: "",
SettingKeyGitHubOAuthRedirectURL: "",
SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
SettingKeyGoogleOAuthEnabled: "false",
SettingKeyGoogleOAuthClientID: "",
SettingKeyGoogleOAuthClientSecret: "",
SettingKeyGoogleOAuthRedirectURL: "",
SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyOIDCConnectClientID: "",
@ -1874,6 +2309,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultGitHubBalance: "0",
SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultGoogleBalance: "0",
SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
SettingKeyForceEmailOnThirdPartySignup: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
@ -1903,6 +2348,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Affiliate (邀请返利) feature (default disabled; opt-in)
SettingKeyAffiliateEnabled: "false",
// 风控中心功能(默认关闭,显式启用)
SettingKeyRiskControlEnabled: "false",
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
@ -1923,6 +2371,11 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
if loginAgreementUpdatedAt == "" {
loginAgreementUpdatedAt = defaultLoginAgreementDate
}
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
@ -1932,6 +2385,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
FrontendURL: settings[SettingKeyFrontendURL],
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true",
LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
@ -2173,6 +2630,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github")
result.GitHubOAuthEnabled = gitHubEffective.Enabled
result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID)
result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret)
result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != ""
result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL)
result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL)
googleEffective := s.effectiveEmailOAuthConfig(settings, "google")
result.GoogleOAuthEnabled = googleEffective.Enabled
result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID)
result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret)
result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != ""
result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL)
result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL)
// WeChat Connect 设置:
// - 优先读取 DB 系统设置
// - 缺失时回退到 config/env保持升级兼容
@ -2242,6 +2715,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Affiliate (邀请返利) feature (default: disabled; strict true)
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
// 风控中心功能(默认关闭,严格 true 才启用)
result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true"
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]

View File

@ -20,6 +20,10 @@ type SystemSettings struct {
FrontendURL string
InvitationCodeEnabled bool
TotpEnabled bool // TOTP 双因素认证
LoginAgreementEnabled bool
LoginAgreementMode string
LoginAgreementUpdatedAt string
LoginAgreementDocuments []LoginAgreementDocument
SMTPHost string
SMTPPort int
@ -89,6 +93,20 @@ type SystemSettings struct {
OIDCConnectUserInfoIDPath string
OIDCConnectUserInfoUsernamePath string
// GitHub / Google 邮箱快捷登录
GitHubOAuthEnabled bool
GitHubOAuthClientID string
GitHubOAuthClientSecret string
GitHubOAuthClientSecretConfigured bool
GitHubOAuthRedirectURL string
GitHubOAuthFrontendRedirectURL string
GoogleOAuthEnabled bool
GoogleOAuthClientID string
GoogleOAuthClientSecret string
GoogleOAuthClientSecretConfigured bool
GoogleOAuthRedirectURL string
GoogleOAuthFrontendRedirectURL string
SiteName string
SiteLogo string
SiteSubtitle string
@ -106,6 +124,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
RiskControlEnabled bool
AffiliateEnabled bool
AffiliateRebateRate float64
AffiliateRebateFreezeHours int
@ -190,6 +209,11 @@ type PublicSettings struct {
PasswordResetEnabled bool
InvitationCodeEnabled bool
TotpEnabled bool // TOTP 双因素认证
LoginAgreementEnabled bool
LoginAgreementMode string
LoginAgreementUpdatedAt string
LoginAgreementRevision string
LoginAgreementDocuments []LoginAgreementDocument
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
@ -217,6 +241,8 @@ type PublicSettings struct {
PaymentEnabled bool
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
GitHubOAuthEnabled bool
GoogleOAuthEnabled bool
Version string
BalanceLowNotifyEnabled bool
@ -233,6 +259,15 @@ type PublicSettings struct {
// Affiliate (邀请返利) feature toggle
AffiliateEnabled bool `json:"affiliate_enabled"`
// 风控中心功能开关
RiskControlEnabled bool `json:"risk_control_enabled"`
}
type LoginAgreementDocument struct {
ID string `json:"id"`
Title string `json:"title"`
ContentMD string `json:"content_md"`
}
type WeChatConnectOAuthConfig struct {

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"strings"
"time"
)
@ -95,6 +96,9 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
// NeedsRefresh 检查token是否需要刷新
// expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
if strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
return false
}
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return account.IsRateLimited()

View File

@ -97,6 +97,8 @@ type UserRepository interface {
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
UpdateConcurrency(ctx context.Context, id int64, amount int) error
BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error)
BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error)
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups幂等冲突忽略

View File

@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
out := make([]UserAuthIdentityRecord, len(m.identities))

View File

@ -515,6 +515,7 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
NewContentModerationService,
NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,

View File

@ -0,0 +1,27 @@
ALTER TABLE users
DROP CONSTRAINT IF EXISTS users_signup_source_check;
ALTER TABLE users
ADD CONSTRAINT users_signup_source_check
CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
ALTER TABLE auth_identities
DROP CONSTRAINT IF EXISTS auth_identities_provider_type_check;
ALTER TABLE auth_identities
ADD CONSTRAINT auth_identities_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
ALTER TABLE auth_identity_channels
DROP CONSTRAINT IF EXISTS auth_identity_channels_provider_type_check;
ALTER TABLE auth_identity_channels
ADD CONSTRAINT auth_identity_channels_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
ALTER TABLE pending_auth_sessions
DROP CONSTRAINT IF EXISTS pending_auth_sessions_provider_type_check;
ALTER TABLE pending_auth_sessions
ADD CONSTRAINT pending_auth_sessions_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));

View File

@ -0,0 +1,45 @@
-- 风控中心内容审计配置与记录
INSERT INTO settings (key, value, updated_at)
VALUES ('risk_control_enabled', 'false', NOW())
ON CONFLICT (key) DO NOTHING;
CREATE TABLE IF NOT EXISTS content_moderation_logs (
id BIGSERIAL PRIMARY KEY,
request_id VARCHAR(128) NOT NULL DEFAULT '',
user_id BIGINT REFERENCES users(id) ON DELETE SET NULL,
user_email VARCHAR(255) NOT NULL DEFAULT '',
api_key_id BIGINT REFERENCES api_keys(id) ON DELETE SET NULL,
api_key_name VARCHAR(100) NOT NULL DEFAULT '',
group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL,
group_name VARCHAR(255) NOT NULL DEFAULT '',
endpoint VARCHAR(128) NOT NULL DEFAULT '',
provider VARCHAR(64) NOT NULL DEFAULT '',
model VARCHAR(255) NOT NULL DEFAULT '',
mode VARCHAR(32) NOT NULL DEFAULT '',
action VARCHAR(32) NOT NULL DEFAULT '',
flagged BOOLEAN NOT NULL DEFAULT FALSE,
highest_category VARCHAR(64) NOT NULL DEFAULT '',
highest_score DECIMAL(8, 6) NOT NULL DEFAULT 0,
category_scores JSONB NOT NULL DEFAULT '{}'::jsonb,
threshold_snapshot JSONB NOT NULL DEFAULT '{}'::jsonb,
input_excerpt TEXT NOT NULL DEFAULT '',
upstream_latency_ms INT,
error TEXT NOT NULL DEFAULT '',
violation_count INT NOT NULL DEFAULT 0,
auto_banned BOOLEAN NOT NULL DEFAULT FALSE,
email_sent BOOLEAN NOT NULL DEFAULT FALSE,
queue_delay_ms INT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS violation_count INT NOT NULL DEFAULT 0;
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS auto_banned BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS email_sent BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS queue_delay_ms INT;
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_created_at ON content_moderation_logs(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_group_created_at ON content_moderation_logs(group_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_flagged_created_at ON content_moderation_logs(flagged, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_user_created_at ON content_moderation_logs(user_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_api_key_created_at ON content_moderation_logs(api_key_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_endpoint_created_at ON content_moderation_logs(endpoint, created_at DESC);

View File

@ -142,3 +142,16 @@ func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T)
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
require.NotContains(t, sql, "detail::jsonb")
}
func TestMigration135AllowsGitHubAndGoogleAuthProviders(t *testing.T) {
content, err := FS.ReadFile("135_allow_email_oauth_provider_types.sql")
require.NoError(t, err)
sql := string(content)
require.Contains(t, sql, "users_signup_source_check")
require.Contains(t, sql, "auth_identities_provider_type_check")
require.Contains(t, sql, "auth_identity_channels_provider_type_check")
require.Contains(t, sql, "pending_auth_sessions_provider_type_check")
require.Contains(t, sql, "'github'")
require.Contains(t, sql, "'google'")
}

View File

@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.2-alpine
ARG GOLANG_IMAGE=golang:1.26.3-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn

View File

@ -202,6 +202,14 @@ gateway:
#
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI请谨慎开启。
force_codex_cli: false
# Enable Codex image-generation bridge injection for /openai/v1/responses.
# 是否为 Codex /responses 请求自动注入 image_generation 工具与桥接指令。
#
# Default false keeps text-only Codex requests text-only. Explicit client-provided
# image_generation tools are still forwarded when the group allows image generation.
# 默认 false保持纯文本 Codex 请求不被改写;客户端显式提供 image_generation tool 时,
# 仍会在分组允许图片生成的情况下正常转发。
codex_image_generation_bridge_enabled: false
# Optional: template file used to build the final top-level Codex `instructions`.
# 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。
#

View File

@ -16,6 +16,8 @@ import type {
TempUnschedulableStatus,
AdminDataPayload,
AdminDataImportResult,
CodexSessionImportRequest,
CodexSessionImportResult,
CheckMixedChannelRequest,
CheckMixedChannelResponse
} from '@/types'
@ -547,6 +549,11 @@ export async function importData(payload: {
return data
}
export async function importCodexSession(payload: CodexSessionImportRequest): Promise<CodexSessionImportResult> {
const { data } = await apiClient.post<CodexSessionImportResult>('/admin/accounts/import/codex-session', payload)
return data
}
/**
* Get Antigravity default model mapping from backend
* @returns Default model mapping (from -> to)
@ -663,6 +670,7 @@ export const accountsAPI = {
syncFromCrs,
exportData,
importData,
importCodexSession,
getAntigravityDefaultModelMapping,
batchClearError,
batchRefresh,

Some files were not shown because too many files have changed in this diff Show More