chore: merge upstream v0.1.124-125, keep Windsurf/Antigravity customizations
Upstream changes: - feat: 邮箱 + GitHub + Google OAuth 快捷登录 - feat: Codex image bridge 开关 - feat: 内容审核 (content moderation) — 新增 contentModerationService/Handler - feat: redeem code 返利、批量并发 API、markdown 页面渲染 - feat: 登录注册条款确认 - fix(security): pages API 加 JWT + 可见性校验 - fix: 修复 markdown 页面图片路径 - fix(gateway): 不再默认注入 redact thinking beta - fix: 稳定 anthropic passthrough 超时错误 - chore: VERSION 升到 0.1.125 + golang:1.26.3-alpine Conflict resolutions: - Dockerfile/backend/Dockerfile: 取 upstream golang:1.26.3-alpine - backend/go.mod: 取 upstream term v0.42.0,保留定制 protobuf v1.36.10 - frontend/src/api/admin/index.ts: 并集 (windsurf + riskControl) - backend/cmd/server/wire_gen.go: 接 upstream contentModeration*,保留 windsurfHandler/windsurfGatewayService/billingCacheService/requestEventBus;并通过 wire 重生成 - frontend/src/views/admin/AccountsView.vue: 采用 upstream 双层布局 + OpenAI Meta,保留 is_enterprise prop 和 Windsurf tier badge Note: - WIP commit (de048fad) preserved Windsurf tier access service / NLU extractor / ops log stream / Google OAuth login modal et al before merge. - 3 pre-existing go vet issues in test files (NewOpsHandler, RegisterGatewayRoutes, DefaultCLIProductVersion) are unrelated to this merge — leftover from local customization refactors; production code (go build ./...) passes.
This commit is contained in:
commit
7347dfffc1
4
.github/workflows/backend-ci.yml
vendored
4
.github/workflows/backend-ci.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.2'
|
||||
go version | grep -q 'go1.26.3'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@ -60,7 +60,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.2'
|
||||
go version | grep -q 'go1.26.3'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@ -115,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.2'
|
||||
go version | grep -q 'go1.26.3'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
|
||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@ -23,7 +23,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.2'
|
||||
go version | grep -q 'go1.26.3'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.3-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
|
||||
@ -72,8 +72,8 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
|
||||
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/register?aff=SUB2API">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
|
||||
@ -71,8 +71,8 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
|
||||
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/register?aff=SUB2API">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
|
||||
@ -71,8 +71,8 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
|
||||
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/register?aff=SUB2API">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
FROM golang:1.26-alpine
|
||||
FROM golang:1.26.3-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@ -1 +1 @@
|
||||
0.1.123
|
||||
0.1.125
|
||||
|
||||
@ -8,11 +8,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
@ -23,9 +18,14 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
)
|
||||
|
||||
@ -74,7 +74,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
|
||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -136,26 +136,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider)
|
||||
usageCache := service.NewUsageCache()
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
windsurfLSService := service.ProvideWindsurfLSService(configConfig)
|
||||
windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository)
|
||||
windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache)
|
||||
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
@ -237,19 +236,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
|
||||
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
||||
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
||||
contentModerationRepository := repository.NewContentModerationRepository(db)
|
||||
contentModerationHashCache := repository.NewContentModerationHashCache(redisClient)
|
||||
contentModerationService := service.NewContentModerationService(settingRepository, contentModerationRepository, contentModerationHashCache, groupRepository, userRepository, apiKeyAuthCacheInvalidator, emailService)
|
||||
contentModerationHandler := admin.NewContentModerationHandler(contentModerationService)
|
||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||
windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService)
|
||||
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
||||
windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository)
|
||||
windsurfTierAccessService := service.ProvideWindsurfTierAccessService(configConfig, accountRepository)
|
||||
windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService, windsurfTierAccessService)
|
||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler, affiliateHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, windsurfHandler, affiliateHandler)
|
||||
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
||||
@ -274,6 +277,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService)
|
||||
application := &Application{
|
||||
@ -485,6 +489,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"WindsurfLSService", func() error {
|
||||
if windsurfLS != nil {
|
||||
windsurfLS.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@ -16,6 +16,8 @@ import (
|
||||
|
||||
var authProviderTypes = map[string]struct{}{
|
||||
"email": {},
|
||||
"github": {},
|
||||
"google": {},
|
||||
"linuxdo": {},
|
||||
"oidc": {},
|
||||
"wechat": {},
|
||||
|
||||
@ -83,10 +83,10 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
|
||||
require.Equal(t, 1, signupSource.Validators)
|
||||
|
||||
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
|
||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
|
||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
|
||||
require.NoError(t, validator(value))
|
||||
}
|
||||
require.Error(t, validator("github"))
|
||||
require.Error(t, validator("unknown"))
|
||||
}
|
||||
|
||||
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
|
||||
|
||||
@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
|
||||
field.String("signup_source").
|
||||
Validate(func(value string) error {
|
||||
switch value {
|
||||
case "email", "linuxdo", "wechat", "oidc":
|
||||
case "email", "linuxdo", "wechat", "oidc", "github", "google":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
|
||||
}
|
||||
}).
|
||||
Default("email"),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.26.2
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
connectrpc.com/connect v1.19.2
|
||||
@ -21,6 +21,7 @@ require (
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/imroc/req/v3 v3.57.0
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pquerna/otp v1.5.0
|
||||
@ -40,11 +41,11 @@ require (
|
||||
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/image v0.39.0
|
||||
golang.org/x/net v0.52.0
|
||||
golang.org/x/net v0.53.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/term v0.42.0
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@ -112,7 +113,6 @@ require (
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
@ -176,7 +176,7 @@ require (
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
golang.org/x/tools v0.43.0 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
|
||||
@ -187,8 +187,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@ -224,8 +222,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@ -259,8 +255,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@ -290,8 +284,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@ -324,8 +316,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
@ -417,16 +407,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
|
||||
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -438,10 +428,10 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
|
||||
@ -79,6 +79,8 @@ type Config struct {
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
|
||||
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
||||
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
|
||||
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
@ -248,6 +250,19 @@ type OIDCConnectConfig struct {
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type EmailOAuthProviderConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
EmailsURL string `mapstructure:"emails_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"`
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultWeChatConnectMode = "open"
|
||||
defaultWeChatConnectScopes = "snsapi_login"
|
||||
@ -619,6 +634,9 @@ type GatewayConfig struct {
|
||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
|
||||
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
|
||||
// CodexImageGenerationBridgeEnabled: 是否为 Codex `/v1/responses` 自动注入 image_generation 工具和桥接指令。
|
||||
// 默认关闭,避免纯文本 Codex 请求被意外改写;显式携带 image_generation 工具的请求仍按分组能力转发。
|
||||
CodexImageGenerationBridgeEnabled bool `mapstructure:"codex_image_generation_bridge_enabled"`
|
||||
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
|
||||
// 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
|
||||
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
|
||||
@ -1773,6 +1791,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.max_account_switches", 10)
|
||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||
viper.SetDefault("gateway.force_codex_cli", false)
|
||||
viper.SetDefault("gateway.codex_image_generation_bridge_enabled", false)
|
||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||
|
||||
1045
backend/internal/handler/admin/account_codex_import.go
Normal file
1045
backend/internal/handler/admin/account_codex_import.go
Normal file
File diff suppressed because it is too large
Load Diff
344
backend/internal/handler/admin/account_codex_import_test.go
Normal file
344
backend/internal/handler/admin/account_codex_import_test.go
Normal file
@ -0,0 +1,344 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseCodexSessionImportEntriesSupportsRawTokenJSONAndArray(t *testing.T) {
|
||||
token1 := "raw-access-token-1"
|
||||
token2 := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{
|
||||
"email": "json@example.com",
|
||||
})
|
||||
token3 := "raw-access-token-3"
|
||||
|
||||
req := CodexSessionImportRequest{
|
||||
Content: fmt.Sprintf("%s\n{\"accessToken\":%q}\n[%q]", token1, token2, token3),
|
||||
}
|
||||
|
||||
entries, err := parseCodexSessionImportEntries(req)
|
||||
if err != nil {
|
||||
t.Fatalf("parseCodexSessionImportEntries error = %v", err)
|
||||
}
|
||||
if len(entries) != 3 {
|
||||
t.Fatalf("len(entries) = %d, want 3", len(entries))
|
||||
}
|
||||
|
||||
first, err := normalizeCodexImportEntry(entries[0])
|
||||
if err != nil {
|
||||
t.Fatalf("normalize raw token error = %v", err)
|
||||
}
|
||||
if first.Credentials["access_token"] != token1 {
|
||||
t.Fatalf("raw token access_token = %v, want %s", first.Credentials["access_token"], token1)
|
||||
}
|
||||
|
||||
second, err := normalizeCodexImportEntry(entries[1])
|
||||
if err != nil {
|
||||
t.Fatalf("normalize json token error = %v", err)
|
||||
}
|
||||
if second.Email != "json@example.com" {
|
||||
t.Fatalf("email = %q, want json@example.com", second.Email)
|
||||
}
|
||||
|
||||
third, err := normalizeCodexImportEntry(entries[2])
|
||||
if err != nil {
|
||||
t.Fatalf("normalize array token error = %v", err)
|
||||
}
|
||||
if third.Credentials["access_token"] != token3 {
|
||||
t.Fatalf("array token access_token = %v, want %s", third.Credentials["access_token"], token3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexSessionImportEntriesFallsBackToLineModeForMixedJSONAndToken(t *testing.T) {
|
||||
req := CodexSessionImportRequest{
|
||||
Content: "{\"accessToken\":\"json-line-token\"}\nraw-line-token",
|
||||
}
|
||||
|
||||
entries, err := parseCodexSessionImportEntries(req)
|
||||
if err != nil {
|
||||
t.Fatalf("parseCodexSessionImportEntries error = %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("len(entries) = %d, want 2", len(entries))
|
||||
}
|
||||
|
||||
first, err := normalizeCodexImportEntry(entries[0])
|
||||
if err != nil {
|
||||
t.Fatalf("normalize json line error = %v", err)
|
||||
}
|
||||
if first.Credentials["access_token"] != "json-line-token" {
|
||||
t.Fatalf("json line access_token = %v, want json-line-token", first.Credentials["access_token"])
|
||||
}
|
||||
|
||||
second, err := normalizeCodexImportEntry(entries[1])
|
||||
if err != nil {
|
||||
t.Fatalf("normalize raw line error = %v", err)
|
||||
}
|
||||
if second.Credentials["access_token"] != "raw-line-token" {
|
||||
t.Fatalf("raw line access_token = %v, want raw-line-token", second.Credentials["access_token"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCodexSessionJSONExtractsCredentialsAndIgnoresSessionToken(t *testing.T) {
|
||||
accessToken := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{
|
||||
"email": "claim@example.com",
|
||||
"https://api.openai.com/auth": map[string]any{
|
||||
"chatgpt_account_id": "acct-from-claim",
|
||||
"chatgpt_user_id": "user-from-claim",
|
||||
"chatgpt_plan_type": "plus",
|
||||
"poid": "org-from-claim",
|
||||
},
|
||||
})
|
||||
raw := map[string]any{
|
||||
"user": map[string]any{
|
||||
"id": "user-from-json",
|
||||
"name": "Sup OO",
|
||||
"email": "json@example.com",
|
||||
"image": "https://example.com/avatar.png",
|
||||
},
|
||||
"account": map[string]any{
|
||||
"id": "acct-from-json",
|
||||
"planType": "free",
|
||||
},
|
||||
"accessToken": accessToken,
|
||||
"sessionToken": "secret-session-token",
|
||||
"expires": "2026-08-05T13:40:42.836Z",
|
||||
}
|
||||
|
||||
item, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: raw})
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeCodexImportEntry error = %v", err)
|
||||
}
|
||||
if item.Credentials["access_token"] != accessToken {
|
||||
t.Fatalf("access_token not stored")
|
||||
}
|
||||
if item.Credentials["email"] != "json@example.com" {
|
||||
t.Fatalf("email = %v, want json@example.com", item.Credentials["email"])
|
||||
}
|
||||
if item.Credentials["chatgpt_account_id"] != "acct-from-json" {
|
||||
t.Fatalf("chatgpt_account_id = %v, want acct-from-json", item.Credentials["chatgpt_account_id"])
|
||||
}
|
||||
if item.Credentials["chatgpt_user_id"] != "user-from-json" {
|
||||
t.Fatalf("chatgpt_user_id = %v, want user-from-json", item.Credentials["chatgpt_user_id"])
|
||||
}
|
||||
if item.Credentials["plan_type"] != "free" {
|
||||
t.Fatalf("plan_type = %v, want free", item.Credentials["plan_type"])
|
||||
}
|
||||
if _, ok := item.Credentials["session_token"]; ok {
|
||||
t.Fatalf("session_token should not be written to credentials")
|
||||
}
|
||||
if item.Extra["session_token_present"] != true {
|
||||
t.Fatalf("session_token_present = %v, want true", item.Extra["session_token_present"])
|
||||
}
|
||||
if item.Extra["session_expires_at"] != "2026-08-05T13:40:42Z" {
|
||||
t.Fatalf("session_expires_at = %v", item.Extra["session_expires_at"])
|
||||
}
|
||||
if item.TokenExpiresAt == nil {
|
||||
t.Fatalf("TokenExpiresAt should be parsed from accessToken")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCodexImportCredentialsClearsStaleRefreshFieldsWhenIncomingHasNoRefreshToken(t *testing.T) {
|
||||
existing := map[string]any{
|
||||
"access_token": "old-access-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"client_id": "old-client-id",
|
||||
"id_token": "old-id-token",
|
||||
"model_mapping": map[string]any{"from": "existing"},
|
||||
"chatgpt_account_id": "acct-old",
|
||||
"unrelated_existing": "keep",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
"access_token": "new-access-token",
|
||||
"expires_at": "2026-08-05T13:40:42Z",
|
||||
"chatgpt_account_id": "acct-new",
|
||||
}
|
||||
item := &codexImportAccount{
|
||||
AccessToken: "new-access-token",
|
||||
}
|
||||
|
||||
merged := mergeCodexImportCredentials(existing, incoming, item)
|
||||
|
||||
if merged["access_token"] != "new-access-token" {
|
||||
t.Fatalf("access_token = %v, want new-access-token", merged["access_token"])
|
||||
}
|
||||
if merged["chatgpt_account_id"] != "acct-new" {
|
||||
t.Fatalf("chatgpt_account_id = %v, want acct-new", merged["chatgpt_account_id"])
|
||||
}
|
||||
if _, ok := merged["refresh_token"]; ok {
|
||||
t.Fatalf("refresh_token should be cleared")
|
||||
}
|
||||
if _, ok := merged["client_id"]; ok {
|
||||
t.Fatalf("client_id should be cleared")
|
||||
}
|
||||
if _, ok := merged["id_token"]; ok {
|
||||
t.Fatalf("id_token should be cleared")
|
||||
}
|
||||
if merged["unrelated_existing"] != "keep" {
|
||||
t.Fatalf("unrelated_existing = %v, want keep", merged["unrelated_existing"])
|
||||
}
|
||||
if _, ok := merged["model_mapping"]; !ok {
|
||||
t.Fatalf("model_mapping should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCodexImportCredentialsKeepsRefreshFieldsWhenIncomingHasRefreshToken(t *testing.T) {
|
||||
existing := map[string]any{
|
||||
"refresh_token": "old-refresh-token",
|
||||
"client_id": "old-client-id",
|
||||
"id_token": "old-id-token",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"client_id": "new-client-id",
|
||||
"id_token": "new-id-token",
|
||||
}
|
||||
item := &codexImportAccount{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token",
|
||||
}
|
||||
|
||||
merged := mergeCodexImportCredentials(existing, incoming, item)
|
||||
|
||||
if merged["refresh_token"] != "new-refresh-token" {
|
||||
t.Fatalf("refresh_token = %v, want new-refresh-token", merged["refresh_token"])
|
||||
}
|
||||
if merged["client_id"] != "new-client-id" {
|
||||
t.Fatalf("client_id = %v, want new-client-id", merged["client_id"])
|
||||
}
|
||||
if merged["id_token"] != "new-id-token" {
|
||||
t.Fatalf("id_token = %v, want new-id-token", merged["id_token"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCodexImportRejectsExpiredAccessToken(t *testing.T) {
|
||||
expiredToken := buildCodexImportTestJWT(t, time.Now().Add(-time.Hour), map[string]any{})
|
||||
|
||||
_, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: expiredToken})
|
||||
if err == nil {
|
||||
t.Fatal("normalizeCodexImportEntry error = nil, want expired token error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "已过期") {
|
||||
t.Fatalf("error = %v, want expired token message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCodexImportExpiryForNoRefreshTokenUsesTokenExpiry(t *testing.T) {
|
||||
tokenExpiresAt := time.Now().Add(time.Hour).UTC()
|
||||
item := &codexImportAccount{
|
||||
AccessToken: "access-token",
|
||||
Credentials: map[string]any{"access_token": "access-token"},
|
||||
TokenExpiresAt: &tokenExpiresAt,
|
||||
WarningTexts: []string{},
|
||||
}
|
||||
disabled := false
|
||||
req := CodexSessionImportRequest{AutoPauseOnExpired: &disabled}
|
||||
|
||||
accountExpiresAt, credentialExpiresAt, autoPause, warnings, err := resolveCodexImportExpiry(req, item)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveCodexImportExpiry error = %v", err)
|
||||
}
|
||||
if accountExpiresAt == nil || *accountExpiresAt != tokenExpiresAt.Unix() {
|
||||
t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, tokenExpiresAt.Unix())
|
||||
}
|
||||
if credentialExpiresAt == nil || credentialExpiresAt.Unix() != tokenExpiresAt.Unix() {
|
||||
t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, tokenExpiresAt)
|
||||
}
|
||||
if autoPause == nil || !*autoPause {
|
||||
t.Fatalf("autoPause = %v, want true", autoPause)
|
||||
}
|
||||
if len(warnings) == 0 {
|
||||
t.Fatalf("warnings should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCodexImportExpiryForNoRefreshTokenRequiresExpiry(t *testing.T) {
|
||||
item := &codexImportAccount{
|
||||
AccessToken: "opaque-access-token",
|
||||
Credentials: map[string]any{"access_token": "opaque-access-token"},
|
||||
WarningTexts: []string{},
|
||||
}
|
||||
|
||||
_, _, _, _, err := resolveCodexImportExpiry(CodexSessionImportRequest{}, item)
|
||||
if err == nil {
|
||||
t.Fatal("resolveCodexImportExpiry error = nil, want missing expiry error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "无法解析 accessToken 过期时间") {
|
||||
t.Fatalf("error = %v, want missing expiry message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCodexImportExpiryForNoRefreshTokenUsesEarlierRequestExpiry(t *testing.T) {
|
||||
tokenExpiresAt := time.Now().Add(2 * time.Hour).UTC()
|
||||
requestExpiresAt := time.Now().Add(time.Hour).UTC()
|
||||
item := &codexImportAccount{
|
||||
AccessToken: "access-token",
|
||||
Credentials: map[string]any{"access_token": "access-token"},
|
||||
TokenExpiresAt: &tokenExpiresAt,
|
||||
WarningTexts: []string{},
|
||||
}
|
||||
reqUnix := requestExpiresAt.Unix()
|
||||
req := CodexSessionImportRequest{ExpiresAt: &reqUnix}
|
||||
|
||||
accountExpiresAt, credentialExpiresAt, _, _, err := resolveCodexImportExpiry(req, item)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveCodexImportExpiry error = %v", err)
|
||||
}
|
||||
if accountExpiresAt == nil || *accountExpiresAt != requestExpiresAt.Unix() {
|
||||
t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, requestExpiresAt.Unix())
|
||||
}
|
||||
if credentialExpiresAt == nil || credentialExpiresAt.Unix() != requestExpiresAt.Unix() {
|
||||
t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, requestExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexIdentityKeysPreferStrongIdentifiers(t *testing.T) {
|
||||
keys := buildCodexIdentityKeys("acct-1", "user-1", "same@example.com", "token")
|
||||
for _, key := range keys {
|
||||
if strings.HasPrefix(key, "email:") {
|
||||
t.Fatalf("strong identity should not include email fallback: %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
keys = buildCodexIdentityKeys("", "", "same@example.com", "token")
|
||||
hasEmail := false
|
||||
for _, key := range keys {
|
||||
if key == "email:same@example.com" {
|
||||
hasEmail = true
|
||||
}
|
||||
}
|
||||
if !hasEmail {
|
||||
t.Fatalf("weak identity should include email fallback: %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func buildCodexImportTestJWT(t *testing.T, exp time.Time, extraClaims map[string]any) string {
|
||||
t.Helper()
|
||||
header := map[string]any{
|
||||
"alg": "none",
|
||||
"typ": "JWT",
|
||||
}
|
||||
claims := map[string]any{
|
||||
"sub": "user-from-sub",
|
||||
"exp": exp.Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
for k, v := range extraClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
headerBytes, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal header: %v", err)
|
||||
}
|
||||
claimBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(headerBytes) + "." + base64.RawURLEncoding.EncodeToString(claimBytes) + "."
|
||||
}
|
||||
@ -175,6 +175,10 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||
return len(userIDs), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
238
backend/internal/handler/admin/content_moderation_handler.go
Normal file
238
backend/internal/handler/admin/content_moderation_handler.go
Normal file
@ -0,0 +1,238 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ContentModerationHandler struct {
|
||||
service *service.ContentModerationService
|
||||
}
|
||||
|
||||
func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler {
|
||||
return &ContentModerationHandler{service: svc}
|
||||
}
|
||||
|
||||
type contentModerationConfigRequest struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
Mode *string `json:"mode"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
Model *string `json:"model"`
|
||||
APIKey *string `json:"api_key"`
|
||||
APIKeys *[]string `json:"api_keys"`
|
||||
APIKeysMode string `json:"api_keys_mode"`
|
||||
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
|
||||
ClearAPIKey bool `json:"clear_api_key"`
|
||||
TimeoutMS *int `json:"timeout_ms"`
|
||||
SampleRate *int `json:"sample_rate"`
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
BlockMessage *string `json:"block_message"`
|
||||
EmailOnHit *bool `json:"email_on_hit"`
|
||||
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||
BanThreshold *int `json:"ban_threshold"`
|
||||
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||
RetryCount *int `json:"retry_count"`
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
}
|
||||
|
||||
type contentModerationAPIKeyTestRequest struct {
|
||||
APIKeys []string `json:"api_keys"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Model string `json:"model"`
|
||||
TimeoutMS int `json:"timeout_ms"`
|
||||
Prompt string `json:"prompt"`
|
||||
Images []string `json:"images"`
|
||||
}
|
||||
|
||||
type contentModerationHashRequest struct {
|
||||
InputHash string `json:"input_hash"`
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) GetConfig(c *gin.Context) {
|
||||
cfg, err := h.service.GetConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
|
||||
var req contentModerationConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{
|
||||
Enabled: req.Enabled,
|
||||
Mode: req.Mode,
|
||||
BaseURL: req.BaseURL,
|
||||
Model: req.Model,
|
||||
APIKey: req.APIKey,
|
||||
APIKeys: req.APIKeys,
|
||||
APIKeysMode: req.APIKeysMode,
|
||||
DeleteAPIKeyHashes: req.DeleteAPIKeyHashes,
|
||||
ClearAPIKey: req.ClearAPIKey,
|
||||
TimeoutMS: req.TimeoutMS,
|
||||
SampleRate: req.SampleRate,
|
||||
AllGroups: req.AllGroups,
|
||||
GroupIDs: req.GroupIDs,
|
||||
RecordNonHits: req.RecordNonHits,
|
||||
WorkerCount: req.WorkerCount,
|
||||
QueueSize: req.QueueSize,
|
||||
BlockStatus: req.BlockStatus,
|
||||
BlockMessage: req.BlockMessage,
|
||||
EmailOnHit: req.EmailOnHit,
|
||||
AutoBanEnabled: req.AutoBanEnabled,
|
||||
BanThreshold: req.BanThreshold,
|
||||
ViolationWindowHours: req.ViolationWindowHours,
|
||||
RetryCount: req.RetryCount,
|
||||
HitRetentionDays: req.HitRetentionDays,
|
||||
NonHitRetentionDays: req.NonHitRetentionDays,
|
||||
PreHashCheckEnabled: req.PreHashCheckEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) {
|
||||
var req contentModerationAPIKeyTestRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{
|
||||
APIKeys: req.APIKeys,
|
||||
BaseURL: req.BaseURL,
|
||||
Model: req.Model,
|
||||
TimeoutMS: req.TimeoutMS,
|
||||
Prompt: req.Prompt,
|
||||
Images: req.Images,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) GetStatus(c *gin.Context) {
|
||||
status, err := h.service.GetStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, status)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) ListLogs(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := service.ContentModerationLogFilter{
|
||||
Pagination: pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
SortOrder: pagination.SortOrderDesc,
|
||||
},
|
||||
Result: c.Query("result"),
|
||||
Endpoint: c.Query("endpoint"),
|
||||
Search: c.Query("search"),
|
||||
}
|
||||
if raw := strings.TrimSpace(c.Query("group_id")); raw != "" {
|
||||
groupID, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil || groupID <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &groupID
|
||||
}
|
||||
if raw := strings.TrimSpace(c.Query("from")); raw != "" {
|
||||
t, _, err := parseContentModerationDate(raw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid from")
|
||||
return
|
||||
}
|
||||
filter.From = &t
|
||||
}
|
||||
if raw := strings.TrimSpace(c.Query("to")); raw != "" {
|
||||
t, dateOnly, err := parseContentModerationDate(raw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid to")
|
||||
return
|
||||
}
|
||||
if dateOnly {
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
}
|
||||
filter.To = &t
|
||||
}
|
||||
items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) UnbanUser(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
result, err := h.service.UnbanUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) {
|
||||
var req contentModerationHashRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) {
|
||||
result, err := h.service.ClearFlaggedInputHashes(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func parseContentModerationDate(raw string) (time.Time, bool, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return t, false, nil
|
||||
}
|
||||
t, err := time.Parse("2006-01-02", raw)
|
||||
return t, err == nil, err
|
||||
}
|
||||
@ -117,6 +117,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
LoginAgreementEnabled: settings.LoginAgreementEnabled,
|
||||
LoginAgreementMode: settings.LoginAgreementMode,
|
||||
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocumentsToDTO(settings.LoginAgreementDocuments),
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@ -169,6 +173,16 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: settings.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecretConfigured: settings.GitHubOAuthClientSecretConfigured,
|
||||
GitHubOAuthRedirectURL: settings.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: settings.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: settings.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecretConfigured: settings.GoogleOAuthClientSecretConfigured,
|
||||
GoogleOAuthRedirectURL: settings.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: settings.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
@ -185,6 +199,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
RiskControlEnabled: settings.RiskControlEnabled,
|
||||
AffiliateRebateRate: settings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||
@ -294,17 +309,50 @@ func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.O
|
||||
return &service.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
func loginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
||||
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, dto.LoginAgreementDocument{
|
||||
ID: item.ID,
|
||||
Title: item.Title,
|
||||
ContentMD: item.ContentMD,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func loginAgreementDocumentsToService(items []dto.LoginAgreementDocument) []service.LoginAgreementDocument {
|
||||
result := make([]service.LoginAgreementDocument, 0, len(items))
|
||||
for _, item := range items {
|
||||
title := strings.TrimSpace(item.Title)
|
||||
content := strings.TrimSpace(item.ContentMD)
|
||||
if title == "" && content == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, service.LoginAgreementDocument{
|
||||
ID: strings.TrimSpace(item.ID),
|
||||
Title: title,
|
||||
ContentMD: content,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||
LoginAgreementDocuments []dto.LoginAgreementDocument `json:"login_agreement_documents"`
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@ -368,6 +416,17 @@ type UpdateSettingsRequest struct {
|
||||
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
||||
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
||||
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GitHubOAuthClientID string `json:"github_oauth_client_id"`
|
||||
GitHubOAuthClientSecret string `json:"github_oauth_client_secret"`
|
||||
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
|
||||
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
GoogleOAuthClientID string `json:"google_oauth_client_id"`
|
||||
GoogleOAuthClientSecret string `json:"google_oauth_client_secret"`
|
||||
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
|
||||
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
@ -413,6 +472,16 @@ type UpdateSettingsRequest struct {
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
|
||||
// Model fallback configuration
|
||||
@ -497,6 +566,9 @@ type UpdateSettingsRequest struct {
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||
|
||||
// 风控中心功能开关
|
||||
RiskControlEnabled *bool `json:"risk_control_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
@ -633,6 +705,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
loginAgreementMode := strings.ToLower(strings.TrimSpace(req.LoginAgreementMode))
|
||||
if loginAgreementMode == "" {
|
||||
loginAgreementMode = strings.ToLower(strings.TrimSpace(previousSettings.LoginAgreementMode))
|
||||
}
|
||||
switch loginAgreementMode {
|
||||
case "", "modal":
|
||||
loginAgreementMode = "modal"
|
||||
case "checkbox":
|
||||
default:
|
||||
response.BadRequest(c, "Login agreement mode must be modal or checkbox")
|
||||
return
|
||||
}
|
||||
loginAgreementUpdatedAt := strings.TrimSpace(req.LoginAgreementUpdatedAt)
|
||||
if loginAgreementUpdatedAt == "" {
|
||||
loginAgreementUpdatedAt = strings.TrimSpace(previousSettings.LoginAgreementUpdatedAt)
|
||||
}
|
||||
loginAgreementDocuments := loginAgreementDocumentsToService(req.LoginAgreementDocuments)
|
||||
if len(loginAgreementDocuments) == 0 {
|
||||
loginAgreementDocuments = previousSettings.LoginAgreementDocuments
|
||||
}
|
||||
for _, doc := range loginAgreementDocuments {
|
||||
if strings.TrimSpace(doc.Title) == "" {
|
||||
response.BadRequest(c, "Login agreement document title is required")
|
||||
return
|
||||
}
|
||||
if len(doc.Title) > 80 {
|
||||
response.BadRequest(c, "Login agreement document title is too long (max 80 characters)")
|
||||
return
|
||||
}
|
||||
if len(doc.ContentMD) > 200*1024 {
|
||||
response.BadRequest(c, "Login agreement document content is too large (max 200KB)")
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.LoginAgreementEnabled && len(loginAgreementDocuments) == 0 {
|
||||
response.BadRequest(c, "Login agreement documents are required when enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
@ -994,17 +1104,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(item.URL) == "" {
|
||||
response.BadRequest(c, "Custom menu item URL is required")
|
||||
return
|
||||
}
|
||||
if len(item.URL) > maxMenuItemURLLen {
|
||||
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
|
||||
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
|
||||
return
|
||||
urlTrimmed := strings.TrimSpace(item.URL)
|
||||
if strings.HasPrefix(urlTrimmed, "md:") {
|
||||
// Markdown page mode: URL = "md:<slug>"
|
||||
slug := strings.TrimPrefix(urlTrimmed, "md:")
|
||||
if slug == "" {
|
||||
response.BadRequest(c, "Custom menu item markdown slug cannot be empty (use md:slug format)")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if urlTrimmed == "" {
|
||||
response.BadRequest(c, "Custom menu item URL is required (use md:slug for markdown pages)")
|
||||
return
|
||||
}
|
||||
if len(item.URL) > maxMenuItemURLLen {
|
||||
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(urlTrimmed); err != nil {
|
||||
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL or md:<slug>")
|
||||
return
|
||||
}
|
||||
}
|
||||
if item.Visibility != "user" && item.Visibility != "admin" {
|
||||
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
|
||||
@ -1148,6 +1268,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
@ -1200,6 +1324,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
@ -1365,6 +1499,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.AffiliateEnabled
|
||||
}(),
|
||||
RiskControlEnabled: func() bool {
|
||||
if req.RiskControlEnabled != nil {
|
||||
return *req.RiskControlEnabled
|
||||
}
|
||||
return previousSettings.RiskControlEnabled
|
||||
}(),
|
||||
}
|
||||
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
@ -1396,6 +1536,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
||||
},
|
||||
GitHub: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultGitHubConcurrency, previousAuthSourceDefaults.GitHub.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
|
||||
},
|
||||
Google: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultGoogleConcurrency, previousAuthSourceDefaults.Google.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
||||
@ -1486,6 +1640,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
LoginAgreementEnabled: updatedSettings.LoginAgreementEnabled,
|
||||
LoginAgreementMode: updatedSettings.LoginAgreementMode,
|
||||
LoginAgreementUpdatedAt: updatedSettings.LoginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocumentsToDTO(updatedSettings.LoginAgreementDocuments),
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
@ -1538,6 +1696,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: updatedSettings.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: updatedSettings.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecretConfigured: updatedSettings.GitHubOAuthClientSecretConfigured,
|
||||
GitHubOAuthRedirectURL: updatedSettings.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: updatedSettings.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: updatedSettings.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: updatedSettings.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecretConfigured: updatedSettings.GoogleOAuthClientSecretConfigured,
|
||||
GoogleOAuthRedirectURL: updatedSettings.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: updatedSettings.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
@ -1616,6 +1784,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||
|
||||
RiskControlEnabled: updatedSettings.RiskControlEnabled,
|
||||
}
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
@ -1685,6 +1855,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
if before.LoginAgreementEnabled != after.LoginAgreementEnabled {
|
||||
changed = append(changed, "login_agreement_enabled")
|
||||
}
|
||||
if before.LoginAgreementMode != after.LoginAgreementMode {
|
||||
changed = append(changed, "login_agreement_mode")
|
||||
}
|
||||
if before.LoginAgreementUpdatedAt != after.LoginAgreementUpdatedAt {
|
||||
changed = append(changed, "login_agreement_updated_at")
|
||||
}
|
||||
if !equalLoginAgreementDocuments(before.LoginAgreementDocuments, after.LoginAgreementDocuments) {
|
||||
changed = append(changed, "login_agreement_documents")
|
||||
}
|
||||
if before.SMTPHost != after.SMTPHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
@ -2004,6 +2186,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AffiliateEnabled != after.AffiliateEnabled {
|
||||
changed = append(changed, "affiliate_enabled")
|
||||
}
|
||||
if before.RiskControlEnabled != after.RiskControlEnabled {
|
||||
changed = append(changed, "risk_control_enabled")
|
||||
}
|
||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||
return changed
|
||||
}
|
||||
@ -2027,6 +2212,8 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
||||
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
|
||||
{name: "oidc", before: before.OIDC, after: after.OIDC},
|
||||
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
||||
{name: "github", before: before.GitHub, after: after.GitHub},
|
||||
{name: "google", before: before.Google, after: after.Google},
|
||||
}
|
||||
for _, field := range fields {
|
||||
if field.before.Balance != field.after.Balance {
|
||||
@ -2141,6 +2328,16 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
||||
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
|
||||
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
|
||||
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
|
||||
data["auth_source_default_github_balance"] = authSourceDefaults.GitHub.Balance
|
||||
data["auth_source_default_github_concurrency"] = authSourceDefaults.GitHub.Concurrency
|
||||
data["auth_source_default_github_subscriptions"] = authSourceDefaults.GitHub.Subscriptions
|
||||
data["auth_source_default_github_grant_on_signup"] = authSourceDefaults.GitHub.GrantOnSignup
|
||||
data["auth_source_default_github_grant_on_first_bind"] = authSourceDefaults.GitHub.GrantOnFirstBind
|
||||
data["auth_source_default_google_balance"] = authSourceDefaults.Google.Balance
|
||||
data["auth_source_default_google_concurrency"] = authSourceDefaults.Google.Concurrency
|
||||
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
|
||||
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
|
||||
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
|
||||
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
||||
|
||||
return data
|
||||
@ -2170,6 +2367,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func equalLoginAgreementDocuments(a, b []service.LoginAgreementDocument) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].ID != b[i].ID || a[i].Title != b[i].Title || a[i].ContentMD != b[i].ContentMD {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func equalIntSlice(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
|
||||
@ -477,3 +477,63 @@ func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
|
||||
|
||||
response.Success(c, status)
|
||||
}
|
||||
|
||||
// BatchUpdateConcurrency 批量修改用户并发数
|
||||
// POST /api/v1/admin/users/batch-concurrency
|
||||
type BatchUpdateConcurrencyRequest struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
All bool `json:"all"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Mode string `json:"mode" binding:"required,oneof=set add"`
|
||||
}
|
||||
|
||||
func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
|
||||
var req BatchUpdateConcurrencyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.All && len(req.UserIDs) == 0 {
|
||||
response.BadRequest(c, "user_ids is required unless all=true")
|
||||
return
|
||||
}
|
||||
if len(req.UserIDs) > 500 {
|
||||
response.BadRequest(c, "user_ids cannot exceed 500")
|
||||
return
|
||||
}
|
||||
|
||||
var userIDs []int64
|
||||
if req.All {
|
||||
// Fetch all user IDs via pagination
|
||||
page := 1
|
||||
const pageSize = 500
|
||||
for {
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, service.UserListFilters{}, "id", "asc")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
for _, u := range users {
|
||||
userIDs = append(userIDs, u.ID)
|
||||
}
|
||||
if len(users) < pageSize {
|
||||
break
|
||||
}
|
||||
page++
|
||||
}
|
||||
} else {
|
||||
userIDs = req.UserIDs
|
||||
}
|
||||
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, gin.H{"affected": 0})
|
||||
return
|
||||
}
|
||||
|
||||
affected, err := h.adminService.BatchUpdateConcurrency(c.Request.Context(), userIDs, req.Concurrency, req.Mode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"affected": affected})
|
||||
}
|
||||
|
||||
621
backend/internal/handler/auth_email_oauth.go
Normal file
621
backend/internal/handler/auth_email_oauth.go
Normal file
@ -0,0 +1,621 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
emailOAuthCookiePath = "/api/v1/auth/oauth"
|
||||
emailOAuthStateCookieName = "email_oauth_state"
|
||||
emailOAuthRedirectCookie = "email_oauth_redirect"
|
||||
emailOAuthProviderCookie = "email_oauth_provider"
|
||||
emailOAuthAffiliateCookie = "email_oauth_affiliate"
|
||||
emailOAuthCookieMaxAgeSec = 10 * 60
|
||||
emailOAuthDefaultRedirect = "/dashboard"
|
||||
)
|
||||
|
||||
type emailOAuthTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type emailOAuthProfile struct {
|
||||
Subject string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
Username string
|
||||
DisplayName string
|
||||
AvatarURL string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
func (h *AuthHandler) GitHubOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "github") }
|
||||
func (h *AuthHandler) GoogleOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "google") }
|
||||
|
||||
func (h *AuthHandler) GitHubOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "github") }
|
||||
func (h *AuthHandler) GoogleOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "google") }
|
||||
func (h *AuthHandler) CompleteGitHubOAuthRegistration(c *gin.Context) {
|
||||
h.completeEmailOAuthRegistration(c, "github")
|
||||
}
|
||||
func (h *AuthHandler) CompleteGoogleOAuthRegistration(c *gin.Context) {
|
||||
h.completeEmailOAuthRegistration(c, "google")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthStart(c *gin.Context, provider string) {
|
||||
cfg, err := h.getEmailOAuthConfig(c.Request.Context(), provider)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||
return
|
||||
}
|
||||
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||
if redirectTo == "" {
|
||||
redirectTo = emailOAuthDefaultRedirect
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
emailOAuthSetCookie(c, emailOAuthStateCookieName, encodeCookieValue(state), secureCookie)
|
||||
emailOAuthSetCookie(c, emailOAuthRedirectCookie, encodeCookieValue(redirectTo), secureCookie)
|
||||
emailOAuthSetCookie(c, emailOAuthProviderCookie, encodeCookieValue(provider), secureCookie)
|
||||
if affCode := strings.TrimSpace(firstNonEmpty(c.Query("aff_code"), c.Query("aff"))); affCode != "" {
|
||||
emailOAuthSetCookie(c, emailOAuthAffiliateCookie, encodeCookieValue(affCode), secureCookie)
|
||||
} else {
|
||||
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
|
||||
}
|
||||
|
||||
authURL, err := buildEmailOAuthAuthorizeURL(cfg, state)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthCallback(c *gin.Context, provider string) {
|
||||
cfg, cfgErr := h.getEmailOAuthConfig(c.Request.Context(), provider)
|
||||
if cfgErr != nil {
|
||||
response.ErrorFrom(c, cfgErr)
|
||||
return
|
||||
}
|
||||
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||
if frontendCallback == "" {
|
||||
frontendCallback = "/auth/oauth/callback"
|
||||
}
|
||||
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if code == "" || state == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
defer func() {
|
||||
emailOAuthClearCookie(c, emailOAuthStateCookieName, secureCookie)
|
||||
emailOAuthClearCookie(c, emailOAuthRedirectCookie, secureCookie)
|
||||
emailOAuthClearCookie(c, emailOAuthProviderCookie, secureCookie)
|
||||
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
|
||||
}()
|
||||
expectedState, err := readCookieDecoded(c, emailOAuthStateCookieName)
|
||||
if err != nil || expectedState == "" || expectedState != state {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||
return
|
||||
}
|
||||
expectedProvider, _ := readCookieDecoded(c, emailOAuthProviderCookie)
|
||||
if !strings.EqualFold(strings.TrimSpace(expectedProvider), provider) {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth provider", "")
|
||||
return
|
||||
}
|
||||
redirectTo, _ := readCookieDecoded(c, emailOAuthRedirectCookie)
|
||||
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||
if redirectTo == "" {
|
||||
redirectTo = emailOAuthDefaultRedirect
|
||||
}
|
||||
|
||||
tokenResp, err := exchangeEmailOAuthCode(c.Request.Context(), cfg, code)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(err.Error()))
|
||||
return
|
||||
}
|
||||
profile, err := fetchEmailOAuthProfile(c.Request.Context(), provider, cfg, tokenResp)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch verified email", singleLine(err.Error()))
|
||||
return
|
||||
}
|
||||
h.emailOAuthCallbackWithProfile(c, provider, cfg, frontendCallback, redirectTo, profile)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthCallbackWithProfile(
|
||||
c *gin.Context,
|
||||
provider string,
|
||||
cfg config.EmailOAuthProviderConfig,
|
||||
frontendCallback string,
|
||||
redirectTo string,
|
||||
profile *emailOAuthProfile,
|
||||
) {
|
||||
input := service.EmailOAuthIdentityInput{
|
||||
ProviderType: provider,
|
||||
ProviderKey: provider,
|
||||
ProviderSubject: profile.Subject,
|
||||
Email: profile.Email,
|
||||
EmailVerified: profile.EmailVerified,
|
||||
Username: profile.Username,
|
||||
DisplayName: profile.DisplayName,
|
||||
AvatarURL: profile.AvatarURL,
|
||||
UpstreamMetadata: profile.Metadata,
|
||||
}
|
||||
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||
if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
|
||||
return
|
||||
} else if shouldCreate {
|
||||
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
|
||||
return
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", tokenPair.AccessToken)
|
||||
fragment.Set("refresh_token", tokenPair.RefreshToken)
|
||||
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) {
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{
|
||||
ProviderType: strings.TrimSpace(input.ProviderType),
|
||||
ProviderKey: strings.TrimSpace(input.ProviderKey),
|
||||
ProviderSubject: strings.TrimSpace(input.ProviderSubject),
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
email := strings.TrimSpace(strings.ToLower(input.Email))
|
||||
if identityUser != nil {
|
||||
if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
|
||||
return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil {
|
||||
if errors.Is(err, service.ErrUserNotFound) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
if code, err := readCookieDecoded(c, emailOAuthAffiliateCookie); err == nil {
|
||||
return strings.TrimSpace(code)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *AuthHandler) createEmailOAuthRegistrationPendingSession(
|
||||
c *gin.Context,
|
||||
provider string,
|
||||
frontendCallback string,
|
||||
redirectTo string,
|
||||
profile *emailOAuthProfile,
|
||||
) error {
|
||||
if h == nil || profile == nil {
|
||||
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
browserSessionKey, err := generateOAuthPendingBrowserSession()
|
||||
if err != nil {
|
||||
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
|
||||
}
|
||||
setOAuthPendingBrowserCookie(c, browserSessionKey, isRequestHTTPS(c))
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||
username := strings.TrimSpace(profile.Username)
|
||||
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||
upstreamClaims := map[string]any{
|
||||
"email": email,
|
||||
"email_verified": profile.EmailVerified,
|
||||
"username": username,
|
||||
"provider": provider,
|
||||
"provider_key": provider,
|
||||
"provider_subject": strings.TrimSpace(profile.Subject),
|
||||
}
|
||||
if strings.TrimSpace(profile.DisplayName) != "" {
|
||||
upstreamClaims["suggested_display_name"] = strings.TrimSpace(profile.DisplayName)
|
||||
}
|
||||
if strings.TrimSpace(profile.AvatarURL) != "" {
|
||||
upstreamClaims["suggested_avatar_url"] = strings.TrimSpace(profile.AvatarURL)
|
||||
}
|
||||
if affiliateCode != "" {
|
||||
upstreamClaims["aff_code"] = affiliateCode
|
||||
}
|
||||
for key, value := range profile.Metadata {
|
||||
if _, exists := upstreamClaims[key]; !exists {
|
||||
upstreamClaims[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context())
|
||||
pendingError := "registration_completion_required"
|
||||
choiceReason := "registration_completion_required"
|
||||
if invitationRequired {
|
||||
pendingError = "invitation_required"
|
||||
choiceReason = "invitation_required"
|
||||
}
|
||||
completionResponse := map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"error": pendingError,
|
||||
"choice_reason": choiceReason,
|
||||
"adoption_required": false,
|
||||
"create_account_allowed": true,
|
||||
"existing_account_bindable": false,
|
||||
"force_email_on_signup": true,
|
||||
"invitation_required": invitationRequired,
|
||||
"email": email,
|
||||
"resolved_email": email,
|
||||
"provider": provider,
|
||||
"redirect": redirectTo,
|
||||
}
|
||||
if strings.TrimSpace(frontendCallback) != "" {
|
||||
completionResponse["frontend_callback"] = strings.TrimSpace(frontendCallback)
|
||||
}
|
||||
|
||||
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: service.PendingAuthIdentityKey{ProviderType: provider, ProviderKey: provider, ProviderSubject: strings.TrimSpace(profile.Subject)},
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: completionResponse,
|
||||
})
|
||||
}
|
||||
|
||||
type completeEmailOAuthRequest struct {
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
InvitationCode string `json:"invitation_code,omitempty"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
}
|
||||
|
||||
func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider string) {
|
||||
var req completeEmailOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
|
||||
response.BadRequest(c, "Pending oauth session provider mismatch")
|
||||
return
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
affiliateCode := strings.TrimSpace(req.AffCode)
|
||||
if affiliateCode == "" {
|
||||
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
|
||||
c.Request.Context(),
|
||||
strings.TrimSpace(session.ResolvedEmail),
|
||||
req.Password,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||
return
|
||||
}
|
||||
tx, err := client.Tx(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
|
||||
sessionForBinding := *session
|
||||
sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims)
|
||||
if strings.TrimSpace(req.InvitationCode) != "" {
|
||||
sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode)
|
||||
}
|
||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{})
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil {
|
||||
_ = tx.Rollback()
|
||||
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||
respondPendingOAuthBindingApplyError(c, err)
|
||||
return
|
||||
}
|
||||
if err := h.authService.FinalizeOAuthEmailAccount(
|
||||
txCtx,
|
||||
user,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
affiliateCode,
|
||||
); err != nil {
|
||||
_ = tx.Rollback()
|
||||
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
|
||||
_ = tx.Rollback()
|
||||
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||
return
|
||||
}
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getEmailOAuthConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetEmailOAuthProviderConfig(ctx, provider)
|
||||
}
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
|
||||
func buildEmailOAuthAuthorizeURL(cfg config.EmailOAuthProviderConfig, state string) (string, error) {
|
||||
u, err := url.Parse(cfg.AuthorizeURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", cfg.ClientID)
|
||||
q.Set("redirect_uri", cfg.RedirectURL)
|
||||
q.Set("state", state)
|
||||
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func exchangeEmailOAuthCode(ctx context.Context, cfg config.EmailOAuthProviderConfig, code string) (*emailOAuthTokenResponse, error) {
|
||||
resp, err := req.C().
|
||||
R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetFormData(map[string]string{
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": cfg.ClientID,
|
||||
"client_secret": cfg.ClientSecret,
|
||||
"code": code,
|
||||
"redirect_uri": cfg.RedirectURL,
|
||||
}).
|
||||
Post(cfg.TokenURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("token endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||
}
|
||||
var tokenResp emailOAuthTokenResponse
|
||||
if err := json.Unmarshal(resp.Bytes(), &tokenResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return nil, errors.New("missing access_token")
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func fetchEmailOAuthProfile(ctx context.Context, provider string, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse) (*emailOAuthProfile, error) {
|
||||
resp, err := req.C().
|
||||
R().
|
||||
SetContext(ctx).
|
||||
SetBearerAuthToken(token.AccessToken).
|
||||
SetHeader("Accept", "application/json").
|
||||
Get(cfg.UserInfoURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("userinfo endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "github":
|
||||
return parseGitHubOAuthProfile(ctx, cfg, token, resp.String())
|
||||
case "google":
|
||||
return parseGoogleOAuthProfile(resp.String())
|
||||
default:
|
||||
return nil, errors.New("unsupported oauth provider")
|
||||
}
|
||||
}
|
||||
|
||||
func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse, body string) (*emailOAuthProfile, error) {
|
||||
subject := strings.TrimSpace(gjson.Get(body, "id").String())
|
||||
if subject == "" {
|
||||
return nil, errors.New("github user id is missing")
|
||||
}
|
||||
email := ""
|
||||
emailsURL := strings.TrimSpace(cfg.EmailsURL)
|
||||
if emailsURL == "" {
|
||||
return nil, errors.New("github verified email is missing")
|
||||
}
|
||||
verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
email = verifiedEmail
|
||||
if email == "" {
|
||||
return nil, errors.New("github verified email is missing")
|
||||
}
|
||||
login := strings.TrimSpace(gjson.Get(body, "login").String())
|
||||
name := strings.TrimSpace(gjson.Get(body, "name").String())
|
||||
return &emailOAuthProfile{
|
||||
Subject: subject,
|
||||
Email: email,
|
||||
EmailVerified: true,
|
||||
Username: firstNonEmpty(login, name, "github_"+subject),
|
||||
DisplayName: firstNonEmpty(name, login),
|
||||
AvatarURL: strings.TrimSpace(gjson.Get(body, "avatar_url").String()),
|
||||
Metadata: map[string]any{
|
||||
"login": login,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fetchGitHubPrimaryVerifiedEmail(ctx context.Context, emailsURL string, accessToken string) (string, error) {
|
||||
resp, err := req.C().
|
||||
R().
|
||||
SetContext(ctx).
|
||||
SetBearerAuthToken(accessToken).
|
||||
SetHeader("Accept", "application/json").
|
||||
Get(emailsURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("github emails endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||
}
|
||||
items := gjson.Parse(resp.String()).Array()
|
||||
for _, item := range items {
|
||||
if item.Get("primary").Bool() && item.Get("verified").Bool() {
|
||||
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, item := range items {
|
||||
if item.Get("verified").Bool() {
|
||||
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", errors.New("github verified email is missing")
|
||||
}
|
||||
|
||||
func parseGoogleOAuthProfile(body string) (*emailOAuthProfile, error) {
|
||||
subject := strings.TrimSpace(gjson.Get(body, "sub").String())
|
||||
email := strings.TrimSpace(gjson.Get(body, "email").String())
|
||||
verified := gjson.Get(body, "email_verified").Bool()
|
||||
if subject == "" {
|
||||
return nil, errors.New("google subject is missing")
|
||||
}
|
||||
if email == "" || !verified {
|
||||
return nil, errors.New("google verified email is missing")
|
||||
}
|
||||
name := strings.TrimSpace(gjson.Get(body, "name").String())
|
||||
return &emailOAuthProfile{
|
||||
Subject: subject,
|
||||
Email: email,
|
||||
EmailVerified: true,
|
||||
Username: firstNonEmpty(strings.TrimSpace(gjson.Get(body, "given_name").String()), name, email),
|
||||
DisplayName: name,
|
||||
AvatarURL: strings.TrimSpace(gjson.Get(body, "picture").String()),
|
||||
Metadata: map[string]any{
|
||||
"email_verified": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func emailOAuthSetCookie(c *gin.Context, name, value string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: emailOAuthCookiePath,
|
||||
MaxAge: emailOAuthCookieMaxAgeSec,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func emailOAuthClearCookie(c *gin.Context, name string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: emailOAuthCookiePath,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
414
backend/internal/handler/auth_email_oauth_test.go
Normal file
414
backend/internal/handler/auth_email_oauth_test.go
Normal file
@ -0,0 +1,414 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
state := "github-oauth-state"
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil)
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)})
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")})
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")})
|
||||
c.Request = req
|
||||
|
||||
profile := &emailOAuthProfile{
|
||||
Subject: "github-123",
|
||||
Email: "fresh@example.com",
|
||||
EmailVerified: true,
|
||||
Username: "fresh",
|
||||
DisplayName: "Fresh User",
|
||||
AvatarURL: "https://cdn.example/fresh.png",
|
||||
Metadata: map[string]any{
|
||||
"login": "fresh",
|
||||
},
|
||||
}
|
||||
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
|
||||
Enabled: true,
|
||||
ClientID: "github-client",
|
||||
ClientSecret: "github-secret",
|
||||
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
|
||||
FrontendRedirectURL: "/auth/oauth/callback",
|
||||
}, "/auth/oauth/callback", "/dashboard", profile)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.Contains(t, location, "/auth/oauth/callback")
|
||||
require.NotContains(t, location, "access_token=")
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
session, err := client.PendingAuthSession.Query().Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "github", session.ProviderType)
|
||||
require.Equal(t, "github", session.ProviderKey)
|
||||
require.Equal(t, "github-123", session.ProviderSubject)
|
||||
require.Equal(t, "fresh@example.com", session.ResolvedEmail)
|
||||
require.Equal(t, "/dashboard", session.RedirectTo)
|
||||
require.Nil(t, session.TargetUserID)
|
||||
|
||||
completion, ok := readCompletionResponse(session.LocalFlowState)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||
require.Equal(t, "invitation_required", completion["error"])
|
||||
require.Equal(t, true, completion["invitation_required"])
|
||||
require.Equal(t, "fresh@example.com", completion["email"])
|
||||
require.Equal(t, "fresh@example.com", completion["resolved_email"])
|
||||
require.Equal(t, true, completion["create_account_allowed"])
|
||||
|
||||
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName))
|
||||
}
|
||||
|
||||
func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("existing@example.com").
|
||||
SetUsername("existing").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil)
|
||||
|
||||
handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{
|
||||
Enabled: true,
|
||||
ClientID: "google-client",
|
||||
ClientSecret: "google-secret",
|
||||
RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback",
|
||||
FrontendRedirectURL: "/auth/oauth/callback",
|
||||
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
|
||||
Subject: "google-123",
|
||||
Email: "existing@example.com",
|
||||
EmailVerified: true,
|
||||
Username: "existing",
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.Contains(t, location, "access_token=")
|
||||
require.Contains(t, location, "redirect=%252Fdashboard")
|
||||
|
||||
sessionCount, err := client.PendingAuthSession.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, sessionCount)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().Where(
|
||||
authidentity.ProviderTypeEQ("google"),
|
||||
authidentity.ProviderSubjectEQ("google-123"),
|
||||
).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, identityCount)
|
||||
_ = user
|
||||
}
|
||||
|
||||
func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) {
|
||||
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
settingValues: map[string]string{
|
||||
service.SettingKeyAffiliateEnabled: "true",
|
||||
},
|
||||
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
|
||||
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
|
||||
},
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil)
|
||||
req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")})
|
||||
c.Request = req
|
||||
|
||||
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
|
||||
Enabled: true,
|
||||
ClientID: "github-client",
|
||||
ClientSecret: "github-secret",
|
||||
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
|
||||
FrontendRedirectURL: "/auth/oauth/callback",
|
||||
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
|
||||
Subject: "github-aff-user",
|
||||
Email: "aff-user@example.com",
|
||||
EmailVerified: true,
|
||||
Username: "aff-user",
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.NotContains(t, recorder.Header().Get("Location"), "access_token=")
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
require.Empty(t, affiliateRepo.ensureUserIDs)
|
||||
require.Empty(t, affiliateRepo.bindCalls)
|
||||
|
||||
session, err := client.PendingAuthSession.Query().Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "aff-user@example.com", session.ResolvedEmail)
|
||||
require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code"))
|
||||
|
||||
completion, ok := readCompletionResponse(session.LocalFlowState)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||
require.Equal(t, "registration_completion_required", completion["error"])
|
||||
require.Equal(t, false, completion["invitation_required"])
|
||||
require.Equal(t, true, completion["create_account_allowed"])
|
||||
require.Equal(t, true, completion["force_email_on_signup"])
|
||||
require.Equal(t, "aff-user@example.com", completion["resolved_email"])
|
||||
}
|
||||
|
||||
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
|
||||
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002})
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
invitationEnabled: true,
|
||||
settingValues: map[string]string{
|
||||
service.SettingKeyAffiliateEnabled: "true",
|
||||
},
|
||||
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
|
||||
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
|
||||
},
|
||||
})
|
||||
ctx := context.Background()
|
||||
invitation, err := client.RedeemCode.Create().
|
||||
SetCode("INVITE456").
|
||||
SetType(service.RedeemTypeInvitation).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("email-oauth-aff-session-token").
|
||||
SetIntent(oauthIntentLogin).
|
||||
SetProviderType("google").
|
||||
SetProviderKey("google").
|
||||
SetProviderSubject("google-aff-user").
|
||||
SetResolvedEmail("pending-aff@example.com").
|
||||
SetRedirectTo("/dashboard").
|
||||
SetBrowserSessionKey("browser-aff-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"email": "pending-aff@example.com",
|
||||
"email_verified": true,
|
||||
"username": "pending-aff",
|
||||
"provider": "google",
|
||||
"provider_key": "google",
|
||||
"provider_subject": "google-aff-user",
|
||||
"aff_code": "AFF456",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"error": "invitation_required",
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
|
||||
c.Request = req
|
||||
|
||||
handler.completeEmailOAuthRegistration(c, "google")
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, user.PasswordHash)
|
||||
require.NotEqual(t, "secret-123", user.PasswordHash)
|
||||
tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, tamperedCount)
|
||||
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
|
||||
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedInvitation.UsedBy)
|
||||
require.Equal(t, user.ID, *storedInvitation.UsedBy)
|
||||
}
|
||||
|
||||
func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("email-oauth-password-session-token").
|
||||
SetIntent(oauthIntentLogin).
|
||||
SetProviderType("github").
|
||||
SetProviderKey("github").
|
||||
SetProviderSubject("github-password-user").
|
||||
SetResolvedEmail("password-required@example.com").
|
||||
SetRedirectTo("/dashboard").
|
||||
SetBrowserSessionKey("browser-password-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"email": "password-required@example.com",
|
||||
"email_verified": true,
|
||||
"username": "password-required",
|
||||
"provider": "github",
|
||||
"provider_key": "github",
|
||||
"provider_subject": "github-password-user",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"error": "registration_completion_required",
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")})
|
||||
c.Request = req
|
||||
|
||||
handler.completeEmailOAuthRegistration(c, "github")
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
}
|
||||
|
||||
func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) {
|
||||
emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "missing scope", http.StatusForbidden)
|
||||
}))
|
||||
t.Cleanup(emailServer.Close)
|
||||
|
||||
profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{
|
||||
EmailsURL: emailServer.URL,
|
||||
}, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, profile)
|
||||
require.Contains(t, err.Error(), "github emails endpoint status 403")
|
||||
}
|
||||
|
||||
type oauthEmailAffiliateBindCall struct {
|
||||
userID int64
|
||||
inviterID int64
|
||||
}
|
||||
|
||||
type oauthEmailAffiliateRepoStub struct {
|
||||
codeOwners map[string]int64
|
||||
ensureUserIDs []int64
|
||||
bindCalls []oauthEmailAffiliateBindCall
|
||||
}
|
||||
|
||||
func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub {
|
||||
return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners}
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) {
|
||||
r.ensureUserIDs = append(r.ensureUserIDs, userID)
|
||||
return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) {
|
||||
userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))]
|
||||
if !ok {
|
||||
return nil, service.ErrAffiliateProfileNotFound
|
||||
}
|
||||
return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) {
|
||||
r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID})
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) {
|
||||
panic("unexpected AccrueQuota call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
|
||||
panic("unexpected GetAccruedRebateFromInvitee call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) {
|
||||
panic("unexpected ThawFrozenQuota call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) {
|
||||
panic("unexpected TransferQuotaToBalance call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) {
|
||||
panic("unexpected ListInvitees call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error {
|
||||
panic("unexpected UpdateUserAffCode call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) {
|
||||
panic("unexpected ResetUserAffCode call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error {
|
||||
panic("unexpected SetUserRebateRate call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error {
|
||||
panic("unexpected BatchSetUserRebateRate call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
|
||||
panic("unexpected ListUsersWithCustomSettings call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
|
||||
panic("unexpected ListAffiliateInviteRecords call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
|
||||
panic("unexpected ListAffiliateRebateRecords call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
|
||||
panic("unexpected ListAffiliateTransferRecords call")
|
||||
}
|
||||
|
||||
func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) {
|
||||
panic("unexpected GetAffiliateUserOverview call")
|
||||
}
|
||||
|
||||
func findSetCookieValue(cookies []*http.Cookie, name string) string {
|
||||
for _, cookie := range cookies {
|
||||
if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 {
|
||||
return cookie.Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@ -2121,6 +2121,8 @@ type oauthPendingFlowTestHandlerOptions struct {
|
||||
emailCache service.EmailCache
|
||||
settingValues map[string]string
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||
affiliateService *service.AffiliateService
|
||||
affiliateFactory func(*dbent.Client, *service.SettingService) *service.AffiliateService
|
||||
totpCache service.TotpCache
|
||||
totpEncryptor service.SecretEncryptor
|
||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||
@ -2160,6 +2162,21 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_affiliates (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
aff_code TEXT NOT NULL UNIQUE,
|
||||
aff_code_custom BOOLEAN NOT NULL DEFAULT false,
|
||||
aff_rebate_rate_percent REAL NULL,
|
||||
inviter_id INTEGER NULL,
|
||||
aff_count INTEGER NOT NULL DEFAULT 0,
|
||||
aff_quota REAL NOT NULL DEFAULT 0,
|
||||
aff_frozen_quota REAL NOT NULL DEFAULT 0,
|
||||
aff_history_quota REAL NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
@ -2177,14 +2194,19 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
},
|
||||
}
|
||||
settingValues := map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
}
|
||||
for key, value := range options.settingValues {
|
||||
settingValues[key] = value
|
||||
}
|
||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
||||
affiliateService := options.affiliateService
|
||||
if affiliateService == nil && options.affiliateFactory != nil {
|
||||
affiliateService = options.affiliateFactory(client, settingSvc)
|
||||
}
|
||||
userRepo := &oauthPendingFlowUserRepo{
|
||||
client: client,
|
||||
options: options.userRepoOptions,
|
||||
@ -2210,7 +2232,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
nil,
|
||||
nil,
|
||||
options.defaultSubAssigner,
|
||||
nil,
|
||||
affiliateService,
|
||||
)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
var totpSvc *service.TotpService
|
||||
@ -2798,6 +2820,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
|
||||
panic("unexpected BatchSetConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
|
||||
panic("unexpected BatchAddConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||
return map[int64]*time.Time{}, nil
|
||||
}
|
||||
|
||||
130
backend/internal/handler/content_moderation_helper.go
Normal file
130
backend/internal/handler/content_moderation_helper.go
Normal file
@ -0,0 +1,130 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||
if h == nil || h.contentModerationService == nil {
|
||||
return nil
|
||||
}
|
||||
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
||||
}
|
||||
|
||||
func contentModerationStatus(decision *service.ContentModerationDecision) int {
|
||||
if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
return decision.StatusCode
|
||||
}
|
||||
|
||||
func contentModerationErrorCode(decision *service.ContentModerationDecision) string {
|
||||
return "content_policy_violation"
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||
if h == nil || h.contentModerationService == nil {
|
||||
return nil
|
||||
}
|
||||
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
||||
}
|
||||
|
||||
func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||
if svc == nil || c == nil || c.Request == nil {
|
||||
return nil
|
||||
}
|
||||
input := buildContentModerationInput(c, apiKey, subject, protocol, model, body)
|
||||
if reqLog != nil {
|
||||
reqLog.Info("content_moderation.gateway_check_start",
|
||||
zap.String("request_id", input.RequestID),
|
||||
zap.Int64("user_id", input.UserID),
|
||||
zap.Int64("api_key_id", input.APIKeyID),
|
||||
zap.String("api_key_name", input.APIKeyName),
|
||||
zap.Int64p("group_id", input.GroupID),
|
||||
zap.String("group_name", input.GroupName),
|
||||
zap.String("endpoint", input.Endpoint),
|
||||
zap.String("provider", input.Provider),
|
||||
zap.String("protocol", input.Protocol),
|
||||
zap.String("model", input.Model),
|
||||
zap.Int("body_bytes", len(body)),
|
||||
)
|
||||
}
|
||||
decision, err := svc.Check(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
if reqLog != nil {
|
||||
reqLog.Warn("content_moderation.check_failed", zap.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if reqLog != nil && decision != nil {
|
||||
reqLog.Info("content_moderation.gateway_check_done",
|
||||
zap.String("request_id", input.RequestID),
|
||||
zap.Bool("allowed", decision.Allowed),
|
||||
zap.Bool("blocked", decision.Blocked),
|
||||
zap.Bool("flagged", decision.Flagged),
|
||||
zap.String("action", decision.Action),
|
||||
zap.Int("status_code", decision.StatusCode),
|
||||
zap.String("highest_category", decision.HighestCategory),
|
||||
zap.Float64("highest_score", decision.HighestScore),
|
||||
)
|
||||
}
|
||||
return decision
|
||||
}
|
||||
|
||||
func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput {
|
||||
input := service.ContentModerationCheckInput{
|
||||
RequestID: contentModerationRequestID(c.Request.Context()),
|
||||
UserID: subject.UserID,
|
||||
Endpoint: GetInboundEndpoint(c),
|
||||
Provider: contentModerationProvider(apiKey),
|
||||
Model: strings.TrimSpace(model),
|
||||
Protocol: protocol,
|
||||
Body: body,
|
||||
}
|
||||
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
input.Provider = strings.TrimSpace(forcedPlatform)
|
||||
}
|
||||
if apiKey != nil {
|
||||
input.APIKeyID = apiKey.ID
|
||||
input.APIKeyName = apiKey.Name
|
||||
if apiKey.User != nil {
|
||||
input.UserEmail = apiKey.User.Email
|
||||
}
|
||||
if apiKey.GroupID != nil {
|
||||
groupID := *apiKey.GroupID
|
||||
input.GroupID = &groupID
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
input.GroupName = apiKey.Group.Name
|
||||
}
|
||||
}
|
||||
if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil {
|
||||
input.Endpoint = c.Request.URL.Path
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func contentModerationProvider(apiKey *service.APIKey) string {
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(apiKey.Group.Platform)
|
||||
}
|
||||
|
||||
func contentModerationRequestID(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok {
|
||||
return strings.TrimSpace(requestID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@ -11,6 +11,7 @@ type CustomMenuItem struct {
|
||||
Label string `json:"label"`
|
||||
IconSVG string `json:"icon_svg"`
|
||||
URL string `json:"url"`
|
||||
PageSlug string `json:"page_slug,omitempty"`
|
||||
Visibility string `json:"visibility"` // "user" or "admin"
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
@ -24,15 +25,19 @@ type CustomEndpoint struct {
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@ -91,6 +96,17 @@ type SystemSettings struct {
|
||||
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
||||
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
||||
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GitHubOAuthClientID string `json:"github_oauth_client_id"`
|
||||
GitHubOAuthClientSecretConfigured bool `json:"github_oauth_client_secret_configured"`
|
||||
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
|
||||
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
GoogleOAuthClientID string `json:"google_oauth_client_id"`
|
||||
GoogleOAuthClientSecretConfigured bool `json:"google_oauth_client_secret_configured"`
|
||||
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
|
||||
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
@ -197,6 +213,9 @@ type SystemSettings struct {
|
||||
// Available Channels feature switch (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
// 风控中心功能开关
|
||||
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
@ -210,45 +229,52 @@ type DefaultSubscriptionSetting struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||
LoginAgreementRevision string `json:"login_agreement_revision"`
|
||||
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
@ -256,6 +282,14 @@ type PublicSettings struct {
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||
}
|
||||
|
||||
type LoginAgreementDocument struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
ContentMD string `json:"content_md"`
|
||||
}
|
||||
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
|
||||
@ -46,6 +46,7 @@ type GatewayHandler struct {
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
contentModerationService *service.ContentModerationService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
userMsgQueueHelper *UserMsgQueueHelper
|
||||
requestEventBus *service.RequestEventBus
|
||||
@ -68,6 +69,7 @@ func NewGatewayHandler(
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
contentModerationService *service.ContentModerationService,
|
||||
userMsgQueueService *service.UserMessageQueueService,
|
||||
cfg *config.Config,
|
||||
settingService *service.SettingService,
|
||||
@ -103,6 +105,7 @@ func NewGatewayHandler(
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
contentModerationService: contentModerationService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
userMsgQueueHelper: umqHelper,
|
||||
requestEventBus: requestEventBus,
|
||||
@ -215,6 +218,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
|
||||
@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
|
||||
googleError(c, contentModerationStatus(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||
reqModel := modelName // 保存映射前的原始模型名
|
||||
|
||||
@ -33,6 +33,7 @@ type AdminHandlers struct {
|
||||
Channel *admin.ChannelHandler
|
||||
ChannelMonitor *admin.ChannelMonitorHandler
|
||||
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
|
||||
ContentModeration *admin.ContentModerationHandler
|
||||
Payment *admin.PaymentHandler
|
||||
Windsurf *admin.WindsurfHandler
|
||||
Affiliate *admin.AffiliateHandler
|
||||
|
||||
@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
|
||||
@ -27,15 +27,16 @@ import (
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
contentModerationService *service.ContentModerationService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
||||
@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler(
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
contentModerationService *service.ContentModerationService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler(
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
contentModerationService: contentModerationService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||
if conn == nil || decision == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
message := strings.TrimSpace(decision.Message)
|
||||
if message == "" {
|
||||
message = "content moderation blocked this request"
|
||||
}
|
||||
payload, err := json.Marshal(gin.H{
|
||||
"event_id": "evt_content_moderation_blocked",
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"code": contentModerationErrorCode(decision),
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`)
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, payload)
|
||||
}
|
||||
|
||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||
if err == nil {
|
||||
return "-", "-"
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
coderws "github.com/coder/websocket"
|
||||
@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
type contentModerationHandlerSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
if value, ok := r.values[key]; ok {
|
||||
return &service.Setting{Key: key, Value: value}, nil
|
||||
}
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if value, ok := r.values[key]; ok {
|
||||
return value, nil
|
||||
}
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||
if r.values == nil {
|
||||
r.values = map[string]string{}
|
||||
}
|
||||
r.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := map[string]string{}
|
||||
for _, key := range keys {
|
||||
if value, ok := r.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = map[string]string{}
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(r.values))
|
||||
for key, value := range r.values {
|
||||
out[key] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
|
||||
delete(r.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
type contentModerationHandlerTestRepo struct {
|
||||
logs []service.ContentModerationLog
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||
if log != nil {
|
||||
r.logs = append(r.logs, *log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||
return &service.ContentModerationCleanupResult{}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/v1/moderations", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
|
||||
}))
|
||||
defer moderationServer.Close()
|
||||
|
||||
cfg := &service.ContentModerationConfig{
|
||||
Enabled: true,
|
||||
Mode: service.ContentModerationModePreBlock,
|
||||
BaseURL: moderationServer.URL,
|
||||
Model: "omni-moderation-latest",
|
||||
APIKeys: []string{"sk-test"},
|
||||
SampleRate: 100,
|
||||
AllGroups: true,
|
||||
BlockMessage: "内容审计测试阻断",
|
||||
}
|
||||
rawCfg, err := json.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := &contentModerationHandlerTestRepo{}
|
||||
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
|
||||
service.SettingKeyRiskControlEnabled: "true",
|
||||
service.SettingKeyContentModerationConfig: string(rawCfg),
|
||||
}}
|
||||
moderationSvc := service.NewContentModerationService(
|
||||
settingRepo,
|
||||
repo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
|
||||
UserID: 1,
|
||||
Endpoint: "/v1/responses",
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.5",
|
||||
Protocol: service.ContentModerationProtocolOpenAIResponses,
|
||||
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
repo.logs = nil
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
contentModerationService: moderationSvc,
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.5",
|
||||
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
|
||||
}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, payload, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
require.Contains(t, string(payload), "content_policy_violation")
|
||||
require.Contains(t, string(payload), "内容审计测试阻断")
|
||||
} else {
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, readErr, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
||||
}
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.True(t, repo.logs[0].Flagged)
|
||||
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
|
||||
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
|
||||
@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !acquired {
|
||||
return
|
||||
|
||||
283
backend/internal/handler/page_handler.go
Normal file
283
backend/internal/handler/page_handler.go
Normal file
@ -0,0 +1,283 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var validSlugPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
|
||||
|
||||
const maxPageFileSize = 1 << 20 // 1MB
|
||||
|
||||
type PageHandler struct {
|
||||
pagesDir string
|
||||
settingService *service.SettingService
|
||||
}
|
||||
|
||||
func NewPageHandler(dataDir string, settingService *service.SettingService) *PageHandler {
|
||||
pagesDir := filepath.Join(dataDir, "pages")
|
||||
_ = os.MkdirAll(pagesDir, 0755)
|
||||
return &PageHandler{pagesDir: pagesDir, settingService: settingService}
|
||||
}
|
||||
|
||||
// GetPageContent serves raw markdown content for a given slug.
|
||||
// GET /api/v1/pages/:slug
|
||||
func (h *PageHandler) GetPageContent(c *gin.Context) {
|
||||
slug := c.Param("slug")
|
||||
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
|
||||
response.BadRequest(c, "Invalid page slug")
|
||||
return
|
||||
}
|
||||
|
||||
// Visibility check: slug must be configured in custom_menu_items
|
||||
// and the user must have permission based on visibility setting
|
||||
if !h.checkSlugVisibility(c, slug) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
|
||||
return
|
||||
}
|
||||
|
||||
filePath := filepath.Join(h.pagesDir, slug+".md")
|
||||
cleaned := filepath.Clean(filePath)
|
||||
if !strings.HasPrefix(cleaned, filepath.Clean(h.pagesDir)) {
|
||||
response.BadRequest(c, "Invalid page slug")
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(cleaned)
|
||||
if err != nil || info.IsDir() {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
|
||||
return
|
||||
}
|
||||
if info.Size() > maxPageFileSize {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "page too large"})
|
||||
return
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(cleaned)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read page"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/markdown; charset=utf-8", content)
|
||||
}
|
||||
|
||||
// ListPages returns available page slugs.
|
||||
// GET /api/v1/pages
|
||||
func (h *PageHandler) ListPages(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.pagesDir)
|
||||
if err != nil {
|
||||
response.Success(c, []string{})
|
||||
return
|
||||
}
|
||||
|
||||
slugs := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if strings.HasSuffix(name, ".md") {
|
||||
slugs = append(slugs, strings.TrimSuffix(name, ".md"))
|
||||
}
|
||||
}
|
||||
response.Success(c, slugs)
|
||||
}
|
||||
|
||||
// ServePageImage serves images from data/pages/{slug}/ directory.
|
||||
// GET /api/v1/pages/:slug/images/*filename
|
||||
// No JWT required (browser img tags can't carry tokens), but visibility is checked.
|
||||
func (h *PageHandler) ServePageImage(c *gin.Context) {
|
||||
slug := c.Param("slug")
|
||||
filename := c.Param("filename")
|
||||
filename = strings.TrimPrefix(filename, "/")
|
||||
|
||||
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if !h.checkImageSlugVisibility(c, slug) {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
imagesDir := filepath.Join(h.pagesDir, slug)
|
||||
cleaned, ok := resolvePageImagePath(h.pagesDir, imagesDir, filename)
|
||||
if !ok {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(cleaned)
|
||||
if err != nil || info.IsDir() {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(cleaned)
|
||||
}
|
||||
|
||||
func resolvePageImagePath(pagesDir, imagesDir, filename string) (string, bool) {
|
||||
relPath, ok := cleanPageImageRelativePath(filename)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cleanedPagesDir := filepath.Clean(pagesDir)
|
||||
cleanedImagesDir := filepath.Clean(imagesDir)
|
||||
cleanedTarget := filepath.Clean(filepath.Join(cleanedImagesDir, relPath))
|
||||
if !isPathWithinBase(cleanedTarget, cleanedImagesDir) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
realPagesDir, err := filepath.EvalSymlinks(cleanedPagesDir)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
realImagesDir, err := filepath.EvalSymlinks(cleanedImagesDir)
|
||||
if err != nil || !isPathWithinBase(realImagesDir, realPagesDir) {
|
||||
return "", false
|
||||
}
|
||||
realTarget, err := filepath.EvalSymlinks(cleanedTarget)
|
||||
if err != nil || !isPathWithinBase(realTarget, realImagesDir) {
|
||||
return "", false
|
||||
}
|
||||
return realTarget, true
|
||||
}
|
||||
|
||||
func cleanPageImageRelativePath(filename string) (string, bool) {
|
||||
if filename == "" {
|
||||
return "", false
|
||||
}
|
||||
if strings.HasPrefix(filename, "/") {
|
||||
return "", false
|
||||
}
|
||||
decoded, err := url.PathUnescape(filename)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
if decoded == "" || strings.HasPrefix(decoded, "/") || strings.Contains(decoded, "\\") || strings.ContainsRune(decoded, 0) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := make([]string, 0)
|
||||
for _, part := range strings.Split(decoded, "/") {
|
||||
switch part {
|
||||
case "", ".":
|
||||
continue
|
||||
case "..":
|
||||
return "", false
|
||||
default:
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
relPath := filepath.Join(parts...)
|
||||
if filepath.IsAbs(relPath) || filepath.VolumeName(relPath) != "" {
|
||||
return "", false
|
||||
}
|
||||
return relPath, true
|
||||
}
|
||||
|
||||
func isPathWithinBase(path, base string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(base), filepath.Clean(path))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return rel != "." && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
|
||||
}
|
||||
|
||||
// findSlugVisibility looks up the slug in custom_menu_items and returns (visibility, found).
|
||||
func (h *PageHandler) findSlugVisibility(c *gin.Context, slug string) (string, bool) {
|
||||
if h.settingService == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw := h.settingService.GetCustomMenuItemsRaw(c.Request.Context())
|
||||
if raw == "" || raw == "[]" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var items []struct {
|
||||
URL string `json:"url"`
|
||||
PageSlug string `json:"page_slug"`
|
||||
Visibility string `json:"visibility"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
itemSlug := item.PageSlug
|
||||
if itemSlug == "" && strings.HasPrefix(item.URL, "md:") {
|
||||
itemSlug = strings.TrimPrefix(item.URL, "md:")
|
||||
}
|
||||
if itemSlug == slug {
|
||||
return item.Visibility, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// checkSlugVisibility verifies the slug is configured in custom_menu_items
|
||||
// and the authenticated user has permission to view it.
|
||||
func (h *PageHandler) checkSlugVisibility(c *gin.Context, slug string) bool {
|
||||
visibility, found := h.findSlugVisibility(c, slug)
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
if visibility == "admin" {
|
||||
role, _ := middleware2.GetUserRoleFromContext(c)
|
||||
return role == "admin"
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkImageSlugVisibility checks visibility for image requests (no JWT available).
|
||||
// Only allows user-visible pages; admin-only pages are blocked.
|
||||
func (h *PageHandler) checkImageSlugVisibility(c *gin.Context, slug string) bool {
|
||||
visibility, found := h.findSlugVisibility(c, slug)
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
return visibility != "admin"
|
||||
}
|
||||
|
||||
// RegisterPageRoutes registers page routes on a router group.
|
||||
func RegisterPageRoutes(v1 *gin.RouterGroup, dataDir string, jwtAuth gin.HandlerFunc, adminAuth gin.HandlerFunc, settingService *service.SettingService) {
|
||||
h := NewPageHandler(dataDir, settingService)
|
||||
|
||||
// Authenticated page content (JWT required + visibility check)
|
||||
pages := v1.Group("/pages")
|
||||
pages.Use(jwtAuth)
|
||||
{
|
||||
pages.GET("/:slug", h.GetPageContent)
|
||||
}
|
||||
|
||||
// Images: no JWT (browser img tags can't carry tokens), visibility check in handler
|
||||
pageImages := v1.Group("/pages")
|
||||
{
|
||||
pageImages.GET("/:slug/images/*filename", h.ServePageImage)
|
||||
}
|
||||
|
||||
// Admin-only: list all available pages
|
||||
adminPages := v1.Group("/pages")
|
||||
adminPages.Use(adminAuth)
|
||||
{
|
||||
adminPages.GET("", h.ListPages)
|
||||
}
|
||||
}
|
||||
102
backend/internal/handler/page_handler_test.go
Normal file
102
backend/internal/handler/page_handler_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCleanPageImageRelativePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{name: "single filename", in: "logo.png", want: "logo.png", ok: true},
|
||||
{name: "nested path", in: "images/logo.png", want: filepath.Join("images", "logo.png"), ok: true},
|
||||
{name: "dot prefix", in: "./logo.png", want: "logo.png", ok: true},
|
||||
{name: "url escaped slash", in: "images%2Flogo.png", want: filepath.Join("images", "logo.png"), ok: true},
|
||||
{name: "parent traversal", in: "../secret.png", ok: false},
|
||||
{name: "encoded parent traversal", in: "%2e%2e/secret.png", ok: false},
|
||||
{name: "backslash traversal", in: `images\secret.png`, ok: false},
|
||||
{name: "absolute path", in: "/etc/passwd", ok: false},
|
||||
{name: "encoded absolute path", in: "%2fetc/passwd", ok: false},
|
||||
{name: "encoded nul byte", in: "logo.png%00", ok: false},
|
||||
{name: "invalid escape", in: "logo.png%zz", ok: false},
|
||||
{name: "empty path", in: "", ok: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := cleanPageImageRelativePath(tt.in)
|
||||
if ok != tt.ok {
|
||||
t.Fatalf("ok = %v, want %v", ok, tt.ok)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("path = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePageImagePath(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
pagesDir := filepath.Join(root, "pages")
|
||||
base := filepath.Join(pagesDir, "guide")
|
||||
if err := os.MkdirAll(filepath.Join(base, "images"), 0755); err != nil {
|
||||
t.Fatalf("create images dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(base, "logo.png"), []byte("fake"), 0644); err != nil {
|
||||
t.Fatalf("create direct image: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(base, "images", "logo.png"), []byte("fake"), 0644); err != nil {
|
||||
t.Fatalf("create image: %v", err)
|
||||
}
|
||||
|
||||
got, ok := resolvePageImagePath(pagesDir, base, "logo.png")
|
||||
if !ok {
|
||||
t.Fatal("expected direct image path to be accepted")
|
||||
}
|
||||
want := filepath.Join(base, "logo.png")
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
got, ok = resolvePageImagePath(pagesDir, base, "images/logo.png")
|
||||
if !ok {
|
||||
t.Fatal("expected nested image path to be accepted")
|
||||
}
|
||||
want = filepath.Join(base, "images", "logo.png")
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if got, ok := resolvePageImagePath(pagesDir, base, "../guide.md"); ok {
|
||||
t.Fatalf("expected traversal to be rejected, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
pagesDir := filepath.Join(root, "pages")
|
||||
base := filepath.Join(pagesDir, "guide")
|
||||
outside := filepath.Join(root, "outside")
|
||||
|
||||
if err := os.MkdirAll(base, 0755); err != nil {
|
||||
t.Fatalf("create page dir: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(outside, 0755); err != nil {
|
||||
t.Fatalf("create outside dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(outside, "secret.png"), []byte("secret"), 0644); err != nil {
|
||||
t.Fatalf("create outside file: %v", err)
|
||||
}
|
||||
if err := os.Symlink(outside, filepath.Join(base, "images")); err != nil {
|
||||
t.Skipf("symlink not supported: %v", err)
|
||||
}
|
||||
|
||||
if got, ok := resolvePageImagePath(pagesDir, base, "images/secret.png"); ok {
|
||||
t.Fatalf("expected symlink escape to be rejected, got %q", got)
|
||||
}
|
||||
}
|
||||
@ -40,6 +40,11 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
LoginAgreementEnabled: settings.LoginAgreementEnabled,
|
||||
LoginAgreementMode: settings.LoginAgreementMode,
|
||||
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
|
||||
LoginAgreementRevision: settings.LoginAgreementRevision,
|
||||
LoginAgreementDocuments: publicLoginAgreementDocumentsToDTO(settings.LoginAgreementDocuments),
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
@ -63,6 +68,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: h.version,
|
||||
@ -77,5 +84,19 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
|
||||
RiskControlEnabled: settings.RiskControlEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
||||
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, dto.LoginAgreementDocument{
|
||||
ID: item.ID,
|
||||
Title: item.Title,
|
||||
ContentMD: item.ContentMD,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@ -87,6 +87,8 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
|
||||
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
|
||||
@ -36,6 +36,7 @@ func ProvideAdminHandlers(
|
||||
channelHandler *admin.ChannelHandler,
|
||||
channelMonitorHandler *admin.ChannelMonitorHandler,
|
||||
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
|
||||
contentModerationHandler *admin.ContentModerationHandler,
|
||||
paymentHandler *admin.PaymentHandler,
|
||||
windsurfHandler *admin.WindsurfHandler,
|
||||
affiliateHandler *admin.AffiliateHandler,
|
||||
@ -68,6 +69,7 @@ func ProvideAdminHandlers(
|
||||
Channel: channelHandler,
|
||||
ChannelMonitor: channelMonitorHandler,
|
||||
ChannelMonitorTemplate: channelMonitorTemplateHandler,
|
||||
ContentModeration: contentModerationHandler,
|
||||
Payment: paymentHandler,
|
||||
Windsurf: windsurfHandler,
|
||||
Affiliate: affiliateHandler,
|
||||
@ -180,6 +182,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewChannelHandler,
|
||||
admin.NewChannelMonitorHandler,
|
||||
admin.NewChannelMonitorRequestTemplateHandler,
|
||||
admin.NewContentModerationHandler,
|
||||
admin.NewPaymentHandler,
|
||||
admin.NewAffiliateHandler,
|
||||
|
||||
|
||||
@ -317,6 +317,7 @@ const CLICurrentVersion = "2.1.92"
|
||||
// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
|
||||
// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
|
||||
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
|
||||
// - 不默认加入 redact-thinking,避免上游抹除 thinking 内容;客户端显式传入时由合并逻辑保留。
|
||||
func FullClaudeCodeMimicryBetas() []string {
|
||||
return []string{
|
||||
BetaClaudeCode,
|
||||
@ -324,7 +325,6 @@ func FullClaudeCodeMimicryBetas() []string {
|
||||
BetaInterleavedThinking,
|
||||
BetaPromptCachingScope,
|
||||
BetaEffort,
|
||||
BetaRedactThinking,
|
||||
BetaContextManagement,
|
||||
BetaExtendedCacheTTL,
|
||||
}
|
||||
|
||||
@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
apikey.FieldID,
|
||||
apikey.FieldUserID,
|
||||
apikey.FieldGroupID,
|
||||
apikey.FieldName,
|
||||
apikey.FieldStatus,
|
||||
apikey.FieldIPWhitelist,
|
||||
apikey.FieldIPBlacklist,
|
||||
|
||||
@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S
|
||||
|
||||
got, err := repo.GetByKeyForAuth(ctx, key.Key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key.Name, got.Name)
|
||||
require.NotNil(t, got.Group)
|
||||
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
|
||||
}
|
||||
|
||||
71
backend/internal/repository/content_moderation_hash_cache.go
Normal file
71
backend/internal/repository/content_moderation_hash_cache.go
Normal file
@ -0,0 +1,71 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes"
|
||||
|
||||
type contentModerationHashCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache {
|
||||
return &contentModerationHashCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
|
||||
inputHash = strings.TrimSpace(inputHash)
|
||||
if c == nil || c.rdb == nil || inputHash == "" {
|
||||
return nil
|
||||
}
|
||||
return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err()
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||||
inputHash = strings.TrimSpace(inputHash)
|
||||
if c == nil || c.rdb == nil || inputHash == "" {
|
||||
return false, nil
|
||||
}
|
||||
return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||||
inputHash = strings.TrimSpace(inputHash)
|
||||
if c == nil || c.rdb == nil || inputHash == "" {
|
||||
return false, nil
|
||||
}
|
||||
deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return deleted > 0, nil
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||||
if c == nil || c.rdb == nil {
|
||||
return 0, nil
|
||||
}
|
||||
deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if deleted == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||||
if c == nil || c.rdb == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
|
||||
}
|
||||
274
backend/internal/repository/content_moderation_repo.go
Normal file
274
backend/internal/repository/content_moderation_repo.go
Normal file
@ -0,0 +1,274 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type contentModerationRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository {
|
||||
return &contentModerationRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||
if log == nil {
|
||||
return nil
|
||||
}
|
||||
categoryScores, err := json.Marshal(log.CategoryScores)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal moderation category scores: %w", err)
|
||||
}
|
||||
thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal moderation thresholds: %w", err)
|
||||
}
|
||||
var userID any
|
||||
if log.UserID != nil {
|
||||
userID = *log.UserID
|
||||
}
|
||||
var apiKeyID any
|
||||
if log.APIKeyID != nil {
|
||||
apiKeyID = *log.APIKeyID
|
||||
}
|
||||
var groupID any
|
||||
if log.GroupID != nil {
|
||||
groupID = *log.GroupID
|
||||
}
|
||||
var latency any
|
||||
if log.UpstreamLatencyMS != nil {
|
||||
latency = *log.UpstreamLatencyMS
|
||||
}
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO content_moderation_logs (
|
||||
request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name,
|
||||
endpoint, provider, model, mode, action, flagged, highest_category, highest_score,
|
||||
category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error,
|
||||
violation_count, auto_banned, email_sent, queue_delay_ms
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9, $10, $11, $12, $13, $14, $15,
|
||||
$16::jsonb, $17::jsonb, $18, $19, $20,
|
||||
$21, $22, $23, $24
|
||||
) RETURNING id, created_at`,
|
||||
log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName,
|
||||
log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore,
|
||||
string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error,
|
||||
log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS),
|
||||
).Scan(&log.ID, &log.CreatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert content moderation log: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||
where, args := buildContentModerationLogWhere(filter)
|
||||
whereSQL := "WHERE " + strings.Join(where, " AND ")
|
||||
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil {
|
||||
return nil, nil, fmt.Errorf("count content moderation logs: %w", err)
|
||||
}
|
||||
|
||||
params := filter.Pagination
|
||||
if params.Page <= 0 {
|
||||
params.Page = 1
|
||||
}
|
||||
if params.PageSize <= 0 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
if params.PageSize > 100 {
|
||||
params.PageSize = 100
|
||||
}
|
||||
queryArgs := append([]any{}, args...)
|
||||
queryArgs = append(queryArgs, params.Limit(), params.Offset())
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name,
|
||||
l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score,
|
||||
l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error,
|
||||
l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at
|
||||
FROM content_moderation_logs l
|
||||
LEFT JOIN users u ON u.id = l.user_id `+whereSQL+`
|
||||
ORDER BY l.created_at DESC, l.id DESC
|
||||
LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)),
|
||||
queryArgs...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list content moderation logs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]service.ContentModerationLog, 0)
|
||||
for rows.Next() {
|
||||
var item service.ContentModerationLog
|
||||
var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64
|
||||
var scoresRaw, thresholdsRaw []byte
|
||||
if err := rows.Scan(
|
||||
&item.ID,
|
||||
&item.RequestID,
|
||||
&userID,
|
||||
&item.UserEmail,
|
||||
&apiKeyID,
|
||||
&item.APIKeyName,
|
||||
&groupID,
|
||||
&item.GroupName,
|
||||
&item.Endpoint,
|
||||
&item.Provider,
|
||||
&item.Model,
|
||||
&item.Mode,
|
||||
&item.Action,
|
||||
&item.Flagged,
|
||||
&item.HighestCategory,
|
||||
&item.HighestScore,
|
||||
&scoresRaw,
|
||||
&thresholdsRaw,
|
||||
&item.InputExcerpt,
|
||||
&latency,
|
||||
&item.Error,
|
||||
&item.ViolationCount,
|
||||
&item.AutoBanned,
|
||||
&item.EmailSent,
|
||||
&item.UserStatus,
|
||||
&queueDelay,
|
||||
&item.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan content moderation log: %w", err)
|
||||
}
|
||||
if userID.Valid {
|
||||
v := userID.Int64
|
||||
item.UserID = &v
|
||||
}
|
||||
if apiKeyID.Valid {
|
||||
v := apiKeyID.Int64
|
||||
item.APIKeyID = &v
|
||||
}
|
||||
if groupID.Valid {
|
||||
v := groupID.Int64
|
||||
item.GroupID = &v
|
||||
}
|
||||
if latency.Valid {
|
||||
v := int(latency.Int64)
|
||||
item.UpstreamLatencyMS = &v
|
||||
}
|
||||
if queueDelay.Valid {
|
||||
v := int(queueDelay.Int64)
|
||||
item.QueueDelayMS = &v
|
||||
}
|
||||
item.CategoryScores = map[string]float64{}
|
||||
_ = json.Unmarshal(scoresRaw, &item.CategoryScores)
|
||||
item.ThresholdSnapshot = map[string]float64{}
|
||||
_ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot)
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err)
|
||||
}
|
||||
return items, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||
if userID <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
WITH last_auto_ban AS (
|
||||
SELECT MAX(created_at) AS at
|
||||
FROM content_moderation_logs
|
||||
WHERE user_id = $1 AND auto_banned = TRUE
|
||||
)
|
||||
SELECT COUNT(*)
|
||||
FROM content_moderation_logs
|
||||
WHERE user_id = $1
|
||||
AND flagged = TRUE
|
||||
AND created_at >= $2
|
||||
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
|
||||
`, userID, since).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count user content moderation flagged logs: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||
result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()}
|
||||
if r == nil || r.db == nil {
|
||||
return result, nil
|
||||
}
|
||||
hitExec, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM content_moderation_logs
|
||||
WHERE flagged = TRUE AND created_at < $1
|
||||
`, hitBefore)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err)
|
||||
}
|
||||
result.DeletedHit, _ = hitExec.RowsAffected()
|
||||
|
||||
nonHitExec, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM content_moderation_logs
|
||||
WHERE flagged = FALSE AND created_at < $1
|
||||
`, nonHitBefore)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err)
|
||||
}
|
||||
result.DeletedNonHit, _ = nonHitExec.RowsAffected()
|
||||
|
||||
result.FinishedAt = time.Now()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func nullableIntPtr(value *int) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) {
|
||||
where := []string{"l.id IS NOT NULL"}
|
||||
args := make([]any, 0)
|
||||
add := func(expr string, value any) {
|
||||
args = append(args, value)
|
||||
where = append(where, fmt.Sprintf(expr, len(args)))
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(filter.Result)) {
|
||||
case "hit", "flagged":
|
||||
where = append(where, "l.flagged = TRUE")
|
||||
case "blocked", "block":
|
||||
where = append(where, "l.action = 'block'")
|
||||
case "pass", "allow":
|
||||
where = append(where, "l.flagged = FALSE AND l.error = ''")
|
||||
case "error":
|
||||
where = append(where, "l.error <> ''")
|
||||
}
|
||||
if filter.GroupID != nil {
|
||||
add("l.group_id = $%d", *filter.GroupID)
|
||||
}
|
||||
if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" {
|
||||
add("l.endpoint = $%d", endpoint)
|
||||
}
|
||||
if search := strings.TrimSpace(filter.Search); search != "" {
|
||||
like := "%" + search + "%"
|
||||
args = append(args, like, like, like, like, like)
|
||||
idx := len(args) - 4
|
||||
where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4))
|
||||
}
|
||||
if filter.From != nil && !filter.From.IsZero() {
|
||||
add("l.created_at >= $%d", *filter.From)
|
||||
}
|
||||
if filter.To != nil && !filter.To.IsZero() {
|
||||
add("l.created_at <= $%d", *filter.To)
|
||||
}
|
||||
return where, args
|
||||
}
|
||||
@ -737,6 +737,37 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if value < 0 {
|
||||
value = 0
|
||||
}
|
||||
res, err := r.sql.ExecContext(ctx,
|
||||
"UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
|
||||
value, pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("batch set concurrency: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
return int(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
res, err := r.sql.ExecContext(ctx,
|
||||
"UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
|
||||
delta, pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("batch add concurrency: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
return int(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewChannelRepository,
|
||||
NewChannelMonitorRepository,
|
||||
NewChannelMonitorRequestTemplateRepository,
|
||||
NewContentModerationRepository,
|
||||
NewAffiliateRepository,
|
||||
|
||||
// Cache implementations
|
||||
@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRefreshTokenCache,
|
||||
NewErrorPassthroughCache,
|
||||
NewTLSFingerprintProfileCache,
|
||||
NewContentModerationHashCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
@ -646,12 +646,21 @@ func TestAPIContracts(t *testing.T) {
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"frontend_url": "",
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
"frontend_url": "",
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"login_agreement_enabled": false,
|
||||
"login_agreement_mode": "modal",
|
||||
"login_agreement_updated_at": "2026-03-31",
|
||||
"login_agreement_documents": [
|
||||
{"id": "terms", "title": "服务条款", "content_md": ""},
|
||||
{"id": "usage-policy", "title": "使用政策", "content_md": ""},
|
||||
{"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""},
|
||||
{"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""}
|
||||
],
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
"smtp_password_configured": true,
|
||||
"smtp_from_email": "no-reply@example.com",
|
||||
"smtp_from_name": "Sub2API",
|
||||
@ -685,6 +694,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"oidc_connect_userinfo_email_path": "",
|
||||
"oidc_connect_userinfo_id_path": "",
|
||||
"oidc_connect_userinfo_username_path": "",
|
||||
"github_oauth_enabled": false,
|
||||
"github_oauth_client_id": "",
|
||||
"github_oauth_client_secret_configured": false,
|
||||
"github_oauth_redirect_url": "",
|
||||
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||
"google_oauth_enabled": false,
|
||||
"google_oauth_client_id": "",
|
||||
"google_oauth_client_secret_configured": false,
|
||||
"google_oauth_redirect_url": "",
|
||||
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||
"ops_monitoring_enabled": false,
|
||||
"ops_realtime_monitoring_enabled": true,
|
||||
"ops_query_mode_default": "auto",
|
||||
@ -700,6 +719,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"auth_source_default_email_subscriptions": [],
|
||||
"auth_source_default_email_grant_on_signup": false,
|
||||
"auth_source_default_email_grant_on_first_bind": false,
|
||||
"auth_source_default_github_balance": 0,
|
||||
"auth_source_default_github_concurrency": 5,
|
||||
"auth_source_default_github_subscriptions": [],
|
||||
"auth_source_default_github_grant_on_signup": false,
|
||||
"auth_source_default_github_grant_on_first_bind": false,
|
||||
"auth_source_default_google_balance": 0,
|
||||
"auth_source_default_google_concurrency": 5,
|
||||
"auth_source_default_google_subscriptions": [],
|
||||
"auth_source_default_google_grant_on_signup": false,
|
||||
"auth_source_default_google_grant_on_first_bind": false,
|
||||
"auth_source_default_linuxdo_balance": 0,
|
||||
"auth_source_default_linuxdo_concurrency": 5,
|
||||
"auth_source_default_linuxdo_subscriptions": [],
|
||||
@ -792,6 +821,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"risk_control_enabled": false,
|
||||
"affiliate_enabled": false,
|
||||
"wechat_connect_enabled": false,
|
||||
"wechat_connect_app_id": "",
|
||||
@ -859,12 +889,21 @@ func TestAPIContracts(t *testing.T) {
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"frontend_url": "",
|
||||
"invitation_code_enabled": false,
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "",
|
||||
"invitation_code_enabled": false,
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"login_agreement_enabled": false,
|
||||
"login_agreement_mode": "modal",
|
||||
"login_agreement_updated_at": "2026-03-31",
|
||||
"login_agreement_documents": [
|
||||
{"id": "terms", "title": "服务条款", "content_md": ""},
|
||||
{"id": "usage-policy", "title": "使用政策", "content_md": ""},
|
||||
{"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""},
|
||||
{"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""}
|
||||
],
|
||||
"smtp_host": "",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "",
|
||||
"smtp_password_configured": false,
|
||||
"smtp_from_email": "",
|
||||
"smtp_from_name": "",
|
||||
@ -898,6 +937,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"oidc_connect_userinfo_email_path": "",
|
||||
"oidc_connect_userinfo_id_path": "",
|
||||
"oidc_connect_userinfo_username_path": "",
|
||||
"github_oauth_enabled": false,
|
||||
"github_oauth_client_id": "",
|
||||
"github_oauth_client_secret_configured": false,
|
||||
"github_oauth_redirect_url": "",
|
||||
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||
"google_oauth_enabled": false,
|
||||
"google_oauth_client_id": "",
|
||||
"google_oauth_client_secret_configured": false,
|
||||
"google_oauth_redirect_url": "",
|
||||
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||
"site_name": "Sub2API",
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subscription to API Conversion Platform",
|
||||
@ -983,6 +1032,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"channel_monitor_enabled": true,
|
||||
"channel_monitor_default_interval_seconds": 60,
|
||||
"available_channels_enabled": false,
|
||||
"risk_control_enabled": false,
|
||||
"affiliate_enabled": false,
|
||||
"wechat_connect_enabled": true,
|
||||
"wechat_connect_app_id": "wx-open-config",
|
||||
@ -1005,6 +1055,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"auth_source_default_email_subscriptions": [],
|
||||
"auth_source_default_email_grant_on_signup": false,
|
||||
"auth_source_default_email_grant_on_first_bind": false,
|
||||
"auth_source_default_github_balance": 0,
|
||||
"auth_source_default_github_concurrency": 5,
|
||||
"auth_source_default_github_subscriptions": [],
|
||||
"auth_source_default_github_grant_on_signup": false,
|
||||
"auth_source_default_github_grant_on_first_bind": false,
|
||||
"auth_source_default_google_balance": 0,
|
||||
"auth_source_default_google_concurrency": 5,
|
||||
"auth_source_default_google_subscriptions": [],
|
||||
"auth_source_default_google_grant_on_signup": false,
|
||||
"auth_source_default_google_grant_on_first_bind": false,
|
||||
"auth_source_default_linuxdo_balance": 0,
|
||||
"auth_source_default_linuxdo_concurrency": 5,
|
||||
"auth_source_default_linuxdo_subscriptions": [],
|
||||
@ -1123,7 +1183,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil, nil)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
|
||||
settingRepo := newStubSettingRepo()
|
||||
@ -1294,6 +1354,9 @@ func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (r *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@ -198,6 +198,9 @@ func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
panic("unexpected ExistsByEmail call")
|
||||
}
|
||||
|
||||
@ -40,6 +40,8 @@ func backendModeAllowsAuthPath(path string) bool {
|
||||
"/auth/oauth/wechat/callback",
|
||||
"/auth/oauth/wechat/payment/callback",
|
||||
"/auth/oauth/oidc/callback",
|
||||
"/auth/oauth/github/callback",
|
||||
"/auth/oauth/google/callback",
|
||||
"/auth/oauth/linuxdo/complete-registration",
|
||||
"/auth/oauth/wechat/complete-registration",
|
||||
"/auth/oauth/oidc/complete-registration",
|
||||
|
||||
@ -246,6 +246,30 @@ func TestBackendModeAuthGuard(t *testing.T) {
|
||||
path: "/api/v1/auth/oauth/oidc/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_github_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/github/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_github_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/github/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_google_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/google/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_google_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/google/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_exchange",
|
||||
enabled: "true",
|
||||
|
||||
@ -120,4 +120,6 @@ func registerRoutes(
|
||||
routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster)
|
||||
|
||||
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
|
||||
|
||||
handler.RegisterPageRoutes(v1, cfg.Pricing.DataDir, gin.HandlerFunc(jwtAuth), gin.HandlerFunc(adminAuth), settingService)
|
||||
}
|
||||
|
||||
@ -95,11 +95,28 @@ func RegisterAdminRoutes(
|
||||
// 渠道监控
|
||||
registerChannelMonitorRoutes(admin, h)
|
||||
|
||||
// 风控中心
|
||||
registerContentModerationRoutes(admin, h)
|
||||
|
||||
// 邀请返利(专属用户管理)
|
||||
registerAffiliateRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
risk := admin.Group("/risk-control")
|
||||
{
|
||||
risk.GET("/config", h.Admin.ContentModeration.GetConfig)
|
||||
risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig)
|
||||
risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys)
|
||||
risk.GET("/status", h.Admin.ContentModeration.GetStatus)
|
||||
risk.GET("/logs", h.Admin.ContentModeration.ListLogs)
|
||||
risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser)
|
||||
risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash)
|
||||
risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes)
|
||||
}
|
||||
}
|
||||
|
||||
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
apiKeys := admin.Group("/api-keys")
|
||||
{
|
||||
@ -234,6 +251,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
||||
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
||||
users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency)
|
||||
|
||||
// User attribute values
|
||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||
@ -270,6 +288,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
||||
accounts.POST("/import/codex-session", h.Admin.Account.ImportCodexSession)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
|
||||
@ -63,6 +63,22 @@ func RegisterAuthRoutes(
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/github/start", h.Auth.GitHubOAuthStart)
|
||||
auth.GET("/oauth/github/callback", h.Auth.GitHubOAuthCallback)
|
||||
auth.POST("/oauth/github/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-github-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteGitHubOAuthRegistration,
|
||||
)
|
||||
auth.GET("/oauth/google/start", h.Auth.GoogleOAuthStart)
|
||||
auth.GET("/oauth/google/callback", h.Auth.GoogleOAuthCallback)
|
||||
auth.POST("/oauth/google/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-google-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteGoogleOAuthRegistration,
|
||||
)
|
||||
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
|
||||
query := c.Request.URL.Query()
|
||||
query.Set("intent", "bind_current_user")
|
||||
|
||||
@ -33,6 +33,7 @@ type AdminService interface {
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
||||
@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||
cleaned := make([]int64, 0, len(userIDs))
|
||||
for _, uid := range userIDs {
|
||||
if uid > 0 {
|
||||
cleaned = append(cleaned, uid)
|
||||
}
|
||||
}
|
||||
if len(cleaned) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var affected int
|
||||
var err error
|
||||
switch mode {
|
||||
case "set":
|
||||
affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value)
|
||||
case "add":
|
||||
affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value)
|
||||
default:
|
||||
return 0, errors.New("invalid mode: must be 'set' or 'add'")
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, uid := range cleaned {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid)
|
||||
}
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
|
||||
@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
if s.existsErr != nil {
|
||||
return false, s.existsErr
|
||||
|
||||
@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||
|
||||
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
|
||||
@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
|
||||
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: apiKey.UserID,
|
||||
GroupID: apiKey.GroupID,
|
||||
Name: apiKey.Name,
|
||||
Status: apiKey.Status,
|
||||
IPWhitelist: apiKey.IPWhitelist,
|
||||
IPBlacklist: apiKey.IPBlacklist,
|
||||
@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
UserID: snapshot.UserID,
|
||||
GroupID: snapshot.GroupID,
|
||||
Key: key,
|
||||
Name: snapshot.Name,
|
||||
Status: snapshot.Status,
|
||||
IPWhitelist: snapshot.IPWhitelist,
|
||||
IPBlacklist: snapshot.IPBlacklist,
|
||||
|
||||
@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
||||
UserID: 2,
|
||||
GroupID: &groupID,
|
||||
Key: "k-roundtrip",
|
||||
Name: "Audit Key",
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 2,
|
||||
@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
||||
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
||||
|
||||
require.NotNil(t, roundTrip)
|
||||
require.Equal(t, apiKey.Name, roundTrip.Name)
|
||||
require.NotNil(t, roundTrip.Group)
|
||||
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
|
||||
}
|
||||
|
||||
274
backend/internal/service/auth_email_oauth_auto.go
Normal file
274
backend/internal/service/auth_email_oauth_auto.go
Normal file
@ -0,0 +1,274 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
type EmailOAuthIdentityInput struct {
|
||||
ProviderType string
|
||||
ProviderKey string
|
||||
ProviderSubject string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
Username string
|
||||
DisplayName string
|
||||
AvatarURL string
|
||||
UpstreamMetadata map[string]any
|
||||
}
|
||||
|
||||
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuth(ctx context.Context, input EmailOAuthIdentityInput) (*TokenPair, *User, error) {
|
||||
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, "", "")
|
||||
}
|
||||
|
||||
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuthWithInvitation(
|
||||
ctx context.Context,
|
||||
input EmailOAuthIdentityInput,
|
||||
invitationCode string,
|
||||
affiliateCode string,
|
||||
) (*TokenPair, *User, error) {
|
||||
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, invitationCode, affiliateCode)
|
||||
}
|
||||
|
||||
func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
|
||||
ctx context.Context,
|
||||
input EmailOAuthIdentityInput,
|
||||
invitationCode string,
|
||||
affiliateCode string,
|
||||
) (*TokenPair, *User, error) {
|
||||
if s == nil || s.userRepo == nil || s.entClient == nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||
if providerType != "github" && providerType != "google" {
|
||||
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
|
||||
}
|
||||
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||
if providerKey == "" {
|
||||
providerKey = providerType
|
||||
}
|
||||
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||
if providerSubject == "" {
|
||||
return nil, nil, infraerrors.BadRequest("OAUTH_SUBJECT_MISSING", "oauth subject is missing")
|
||||
}
|
||||
if !input.EmailVerified {
|
||||
return nil, nil, infraerrors.Forbidden("OAUTH_EMAIL_NOT_VERIFIED", "oauth email is not verified")
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(input.Email))
|
||||
if email == "" || len(email) > 255 {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if isReservedEmail(email) {
|
||||
return nil, nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
identityUser, err := s.findEmailOAuthIdentityOwner(ctx, providerType, providerKey, providerSubject)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if identityUser != nil && !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
|
||||
return nil, nil, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
|
||||
}
|
||||
|
||||
user := identityUser
|
||||
created := false
|
||||
if user == nil {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
user, err = s.createEmailOAuthUser(ctx, email, input.Username, providerType, invitationCode, affiliateCode)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
created = true
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during %s oauth login: %v", providerType, err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
if err := s.ensureEmailOAuthIdentity(ctx, user.ID, EmailOAuthIdentityInput{
|
||||
ProviderType: providerType,
|
||||
ProviderKey: providerKey,
|
||||
ProviderSubject: providerSubject,
|
||||
Email: email,
|
||||
EmailVerified: input.EmailVerified,
|
||||
Username: input.Username,
|
||||
DisplayName: input.DisplayName,
|
||||
AvatarURL: input.AvatarURL,
|
||||
UpstreamMetadata: input.UpstreamMetadata,
|
||||
}); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if user.Username == "" && strings.TrimSpace(input.Username) != "" {
|
||||
user.Username = strings.TrimSpace(input.Username)
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after %s oauth login: %v", providerType, err)
|
||||
}
|
||||
}
|
||||
if !created {
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, providerType); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply %s first bind defaults: %v", providerType, err)
|
||||
}
|
||||
}
|
||||
s.RecordSuccessfulLogin(ctx, user.ID)
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, providerType, invitationCode, affiliateCode string) (*User, error) {
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvitationCodeRequired) {
|
||||
return nil, ErrOAuthInvitationRequired
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, providerType)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
user := &User{
|
||||
Email: email,
|
||||
Username: strings.TrimSpace(username),
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: providerType,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
existing, loadErr := s.userRepo.GetByEmail(ctx, email)
|
||||
if loadErr != nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, providerType, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, invitationCode)
|
||||
return nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) findEmailOAuthIdentityOwner(ctx context.Context, providerType, providerKey, providerSubject string) (*User, error) {
|
||||
identity, err := s.entClient.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderSubjectEQ(providerSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||
}
|
||||
user, err := s.userRepo.GetByID(ctx, identity.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailOAuthIdentity(ctx context.Context, userID int64, input EmailOAuthIdentityInput) error {
|
||||
metadata := map[string]any{
|
||||
"email": strings.TrimSpace(strings.ToLower(input.Email)),
|
||||
"email_verified": input.EmailVerified,
|
||||
}
|
||||
for key, value := range input.UpstreamMetadata {
|
||||
metadata[key] = value
|
||||
}
|
||||
if strings.TrimSpace(input.Username) != "" {
|
||||
metadata["username"] = strings.TrimSpace(input.Username)
|
||||
}
|
||||
if strings.TrimSpace(input.DisplayName) != "" {
|
||||
metadata["display_name"] = strings.TrimSpace(input.DisplayName)
|
||||
}
|
||||
if strings.TrimSpace(input.AvatarURL) != "" {
|
||||
metadata["avatar_url"] = strings.TrimSpace(input.AvatarURL)
|
||||
}
|
||||
|
||||
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||
identity, err := s.entClient.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderSubjectEQ(providerSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||
}
|
||||
if identity != nil {
|
||||
if identity.UserID != userID {
|
||||
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
_, err = s.entClient.AuthIdentity.UpdateOneID(identity.ID).
|
||||
SetMetadata(metadata).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
_, err = s.entClient.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType(providerType).
|
||||
SetProviderKey(providerKey).
|
||||
SetProviderSubject(providerSubject).
|
||||
SetMetadata(metadata).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
@ -10,6 +10,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func normalizeOAuthSignupSource(signupSource string) string {
|
||||
@ -17,7 +18,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
|
||||
switch signupSource {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc":
|
||||
case "linuxdo", "wechat", "oidc", "github", "google":
|
||||
return signupSource
|
||||
default:
|
||||
return "email"
|
||||
@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth
|
||||
// provider that has already returned a verified email address.
|
||||
func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
|
||||
ctx context.Context,
|
||||
email string,
|
||||
password string,
|
||||
invitationCode string,
|
||||
signupSource string,
|
||||
) (*TokenPair, *User, error) {
|
||||
if s == nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if email == "" || len(email) > 255 {
|
||||
return nil, nil, ErrEmailVerifyRequired
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, nil, ErrEmailVerifyRequired
|
||||
}
|
||||
if isReservedEmail(email) {
|
||||
return nil, nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
|
||||
}
|
||||
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return nil, nil, ErrEmailExists
|
||||
}
|
||||
|
||||
hashedPassword, err := s.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
signupSource = normalizeOAuthSignupSource(signupSource)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
user := &User{
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return nil, nil, ErrEmailExists
|
||||
}
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
|
||||
// only after the pending OAuth flow has fully reached its last reversible step.
|
||||
func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
|
||||
@ -229,6 +229,67 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes
|
||||
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
|
||||
}
|
||||
|
||||
func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
signupSource string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "github",
|
||||
email: "github@example.com",
|
||||
signupSource: " GitHub ",
|
||||
want: "github",
|
||||
},
|
||||
{
|
||||
name: "google",
|
||||
email: "google@example.com",
|
||||
signupSource: " Google ",
|
||||
want: "google",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 43}
|
||||
emailCache := &emailCacheStub{
|
||||
data: &VerificationCodeData{
|
||||
Code: "246810",
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||
},
|
||||
}
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
&redeemCodeRepoStub{},
|
||||
&refreshTokenCacheStub{},
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
context.Background(),
|
||||
tt.email,
|
||||
"secret-123",
|
||||
"246810",
|
||||
"",
|
||||
tt.signupSource,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
require.Len(t, userRepo.created, 1)
|
||||
require.Equal(t, tt.want, userRepo.created[0].SignupSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 43}
|
||||
emailCache := &emailCacheStub{
|
||||
@ -256,7 +317,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing
|
||||
"secret-123",
|
||||
"246810",
|
||||
"",
|
||||
"github",
|
||||
"unknown-provider",
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -775,6 +775,10 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
||||
return defaults.OIDC, true
|
||||
case "wechat":
|
||||
return defaults.WeChat, true
|
||||
case "github":
|
||||
return defaults.GitHub, true
|
||||
case "google":
|
||||
return defaults.Google, true
|
||||
default:
|
||||
return ProviderDefaultGrantSettings{}, false
|
||||
}
|
||||
|
||||
@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
|
||||
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
64
backend/internal/service/codex_image_generation_bridge.go
Normal file
64
backend/internal/service/codex_image_generation_bridge.go
Normal file
@ -0,0 +1,64 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
const featureKeyCodexImageGenerationBridge = "codex_image_generation_bridge"
|
||||
|
||||
func boolOverridePtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func boolOverrideFromMap(values map[string]any, keys ...string) *bool {
|
||||
if values == nil {
|
||||
return nil
|
||||
}
|
||||
for _, key := range keys {
|
||||
if v, ok := values[key].(bool); ok {
|
||||
return boolOverridePtr(v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func platformBoolOverride(values map[string]any, key string, platform string) *bool {
|
||||
if values == nil {
|
||||
return nil
|
||||
}
|
||||
if v, ok := values[key].(bool); ok {
|
||||
return boolOverridePtr(v)
|
||||
}
|
||||
raw, ok := values[key].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
platform = strings.TrimSpace(platform)
|
||||
if platform == "" {
|
||||
return nil
|
||||
}
|
||||
if v, ok := raw[platform].(bool); ok {
|
||||
return boolOverridePtr(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CodexImageGenerationBridgeOverride returns the channel-level override for Codex
|
||||
// image_generation bridge injection. Nil means follow the global/account policy.
|
||||
func (c *Channel) CodexImageGenerationBridgeOverride(platform string) *bool {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return platformBoolOverride(c.FeaturesConfig, featureKeyCodexImageGenerationBridge, platform)
|
||||
}
|
||||
|
||||
// CodexImageGenerationBridgeOverride returns the account-level override for Codex
|
||||
// image_generation bridge injection. Nil means follow the channel/global policy.
|
||||
func (a *Account) CodexImageGenerationBridgeOverride() *bool {
|
||||
if a == nil || a.Platform != PlatformOpenAI || a.Extra == nil {
|
||||
return nil
|
||||
}
|
||||
if override := boolOverrideFromMap(a.Extra, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled"); override != nil {
|
||||
return override
|
||||
}
|
||||
openaiConfig, _ := a.Extra[PlatformOpenAI].(map[string]any)
|
||||
return boolOverrideFromMap(openaiConfig, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled")
|
||||
}
|
||||
2048
backend/internal/service/content_moderation.go
Normal file
2048
backend/internal/service/content_moderation.go
Normal file
File diff suppressed because it is too large
Load Diff
117
backend/internal/service/content_moderation_email.go
Normal file
117
backend/internal/service/content_moderation_email.go
Normal file
@ -0,0 +1,117 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
|
||||
if log == nil {
|
||||
return ""
|
||||
}
|
||||
userName := strings.TrimSpace(log.UserEmail)
|
||||
if userName == "" && log.UserID != nil {
|
||||
userName = fmt.Sprintf("UID %d", *log.UserID)
|
||||
}
|
||||
threshold := cfg.BanThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = defaultContentModerationBanThreshold
|
||||
}
|
||||
statusBlock := ""
|
||||
if log.AutoBanned {
|
||||
statusBlock = `<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>`
|
||||
}
|
||||
return fmt.Sprintf(`<!doctype html>
|
||||
<html>
|
||||
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
|
||||
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
|
||||
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
|
||||
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
|
||||
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 风控提醒</div>
|
||||
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户触发内容审计规则</h1>
|
||||
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>,您的 API 请求在内容审计中触发平台风控策略。详情如下。</p>
|
||||
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
|
||||
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">触发详情</h2>
|
||||
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 次(阈值 %d)</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
%s
|
||||
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送,请勿回复。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
html.EscapeString(userName),
|
||||
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
|
||||
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
|
||||
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
|
||||
log.HighestScore,
|
||||
log.ViolationCount,
|
||||
threshold,
|
||||
statusBlock,
|
||||
html.EscapeString(siteName),
|
||||
)
|
||||
}
|
||||
|
||||
func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
|
||||
if log == nil {
|
||||
return ""
|
||||
}
|
||||
userName := strings.TrimSpace(log.UserEmail)
|
||||
if userName == "" && log.UserID != nil {
|
||||
userName = fmt.Sprintf("UID %d", *log.UserID)
|
||||
}
|
||||
threshold := cfg.BanThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = defaultContentModerationBanThreshold
|
||||
}
|
||||
return fmt.Sprintf(`<!doctype html>
|
||||
<html>
|
||||
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
|
||||
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
|
||||
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
|
||||
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
|
||||
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 账户封禁</div>
|
||||
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户已被自动禁用</h1>
|
||||
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>,您的账户在计数周期内多次触发平台风控策略,系统已自动禁用该账户。详情如下。</p>
|
||||
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
|
||||
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">封禁详情</h2>
|
||||
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">封禁时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
|
||||
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 次(阈值 %d)</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>
|
||||
<p style="font-size:15px;line-height:1.8;color:#666;margin-top:24px;">如需申诉或恢复账号,请联系平台管理员处理。</p>
|
||||
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送,请勿回复。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
html.EscapeString(userName),
|
||||
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
|
||||
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
|
||||
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
|
||||
log.HighestScore,
|
||||
log.ViolationCount,
|
||||
threshold,
|
||||
html.EscapeString(siteName),
|
||||
)
|
||||
}
|
||||
|
||||
func defaultContentModerationString(value string, fallback string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
320
backend/internal/service/content_moderation_input.go
Normal file
320
backend/internal/service/content_moderation_input.go
Normal file
@ -0,0 +1,320 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func ExtractContentModerationText(protocol string, body []byte) string {
|
||||
return ExtractContentModerationInput(protocol, body).Text
|
||||
}
|
||||
|
||||
func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return ContentModerationInput{}
|
||||
}
|
||||
var parts []string
|
||||
var images []string
|
||||
switch protocol {
|
||||
case ContentModerationProtocolAnthropicMessages:
|
||||
collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images)
|
||||
case ContentModerationProtocolOpenAIChat:
|
||||
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
|
||||
case ContentModerationProtocolOpenAIResponses:
|
||||
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
|
||||
case ContentModerationProtocolGemini:
|
||||
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
|
||||
case ContentModerationProtocolOpenAIImages:
|
||||
addModerationText(&parts, gjson.GetBytes(body, "prompt").String())
|
||||
collectContentValue(gjson.GetBytes(body, "images"), &parts, &images)
|
||||
default:
|
||||
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
|
||||
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
|
||||
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
|
||||
}
|
||||
out := ContentModerationInput{
|
||||
Text: normalizeContentModerationText(strings.Join(parts, "\n")),
|
||||
Images: normalizeModerationImages(images),
|
||||
}
|
||||
out.Normalize()
|
||||
return out
|
||||
}
|
||||
|
||||
func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) {
|
||||
if !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
var lastParts []string
|
||||
var lastImages []string
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||
lastParts = candidate
|
||||
lastImages = candidateImages
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
*parts = append(*parts, lastParts...)
|
||||
*images = append(*images, lastImages...)
|
||||
}
|
||||
|
||||
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
|
||||
if !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
var lastParts []string
|
||||
var lastImages []string
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||
lastParts = candidate
|
||||
lastImages = candidateImages
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
*parts = append(*parts, lastParts...)
|
||||
*images = append(*images, lastImages...)
|
||||
}
|
||||
|
||||
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
|
||||
switch {
|
||||
case !value.Exists():
|
||||
return
|
||||
case value.Type == gjson.String:
|
||||
if !isAnthropicSystemReminderText(value.String()) {
|
||||
addModerationText(parts, value.String())
|
||||
}
|
||||
case value.IsArray():
|
||||
value.ForEach(func(_, item gjson.Result) bool {
|
||||
collectAnthropicUserContentValue(item, parts, images)
|
||||
return true
|
||||
})
|
||||
case value.IsObject():
|
||||
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
|
||||
switch typ {
|
||||
case "", "text", "input_text", "message":
|
||||
if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) {
|
||||
addModerationText(parts, value.Get("text").String())
|
||||
}
|
||||
if value.Get("content").Exists() {
|
||||
collectAnthropicUserContentValue(value.Get("content"), parts, images)
|
||||
}
|
||||
case "image_url", "input_image", "image":
|
||||
collectContentValue(value, parts, images)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isAnthropicSystemReminderText(text string) bool {
|
||||
return strings.HasPrefix(strings.TrimSpace(text), "<system-reminder>")
|
||||
}
|
||||
|
||||
func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) {
|
||||
switch {
|
||||
case !input.Exists():
|
||||
return
|
||||
case input.Type == gjson.String:
|
||||
addModerationText(parts, input.String())
|
||||
case input.IsArray():
|
||||
var last gjson.Result
|
||||
input.ForEach(func(_, item gjson.Result) bool {
|
||||
if isResponsesUserTextItem(item) {
|
||||
last = item
|
||||
}
|
||||
return true
|
||||
})
|
||||
if last.Exists() {
|
||||
collectContentValue(last.Get("content"), parts, images)
|
||||
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
|
||||
collectContentValue(last, parts, images)
|
||||
}
|
||||
}
|
||||
case input.IsObject():
|
||||
if isResponsesUserTextItem(input) {
|
||||
collectContentValue(input.Get("content"), parts, images)
|
||||
if input.Get("type").String() == "input_text" || input.Get("text").Exists() {
|
||||
collectContentValue(input, parts, images)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isResponsesUserTextItem(item gjson.Result) bool {
|
||||
role := strings.ToLower(strings.TrimSpace(item.Get("role").String()))
|
||||
if role == "user" {
|
||||
return responseItemHasModerationText(item)
|
||||
}
|
||||
if role != "" {
|
||||
return false
|
||||
}
|
||||
return responseItemHasModerationText(item)
|
||||
}
|
||||
|
||||
func responseItemHasModerationText(item gjson.Result) bool {
|
||||
var parts []string
|
||||
var images []string
|
||||
collectContentValue(item.Get("content"), &parts, &images)
|
||||
if item.Get("type").String() == "input_text" || item.Get("text").Exists() {
|
||||
collectContentValue(item, &parts, &images)
|
||||
}
|
||||
return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0
|
||||
}
|
||||
|
||||
func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) {
|
||||
if !contents.IsArray() {
|
||||
return
|
||||
}
|
||||
var lastParts []string
|
||||
var lastImages []string
|
||||
contents.ForEach(func(_, content gjson.Result) bool {
|
||||
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
|
||||
if role == "" || role == "user" {
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
if arr := content.Get("parts"); arr.IsArray() {
|
||||
arr.ForEach(func(_, part gjson.Result) bool {
|
||||
addModerationText(&candidate, part.Get("text").String())
|
||||
addGeminiModerationImage(&candidateImages, part)
|
||||
return true
|
||||
})
|
||||
}
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||
lastParts = candidate
|
||||
lastImages = candidateImages
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
*parts = append(*parts, lastParts...)
|
||||
*images = append(*images, lastImages...)
|
||||
}
|
||||
|
||||
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {
|
||||
switch {
|
||||
case !value.Exists():
|
||||
return
|
||||
case value.Type == gjson.String:
|
||||
addModerationText(parts, value.String())
|
||||
case value.IsArray():
|
||||
value.ForEach(func(_, item gjson.Result) bool {
|
||||
collectContentValue(item, parts, images)
|
||||
return true
|
||||
})
|
||||
case value.IsObject():
|
||||
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
|
||||
addModerationImage(images, value.Get("image_url.url").String())
|
||||
addModerationImage(images, value.Get("image_url").String())
|
||||
addModerationImage(images, value.Get("url").String())
|
||||
addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String())
|
||||
addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String())
|
||||
addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String())
|
||||
addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String())
|
||||
addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String())
|
||||
addModerationImage(images, value.Get("source.data").String())
|
||||
addModerationImage(images, value.Get("data").String())
|
||||
addModerationImage(images, value.Get("base64").String())
|
||||
switch typ {
|
||||
case "", "text", "input_text", "message":
|
||||
if value.Get("text").Exists() {
|
||||
addModerationText(parts, value.Get("text").String())
|
||||
}
|
||||
if value.Get("content").Exists() {
|
||||
collectContentValue(value.Get("content"), parts, images)
|
||||
}
|
||||
case "image_url", "input_image", "image":
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addGeminiModerationImage(images *[]string, part gjson.Result) {
|
||||
if inlineData := part.Get("inline_data"); inlineData.IsObject() {
|
||||
mimeType := strings.TrimSpace(inlineData.Get("mime_type").String())
|
||||
data := strings.TrimSpace(inlineData.Get("data").String())
|
||||
if mimeType != "" && data != "" {
|
||||
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
|
||||
}
|
||||
}
|
||||
if inlineData := part.Get("inlineData"); inlineData.IsObject() {
|
||||
mimeType := strings.TrimSpace(inlineData.Get("mimeType").String())
|
||||
data := strings.TrimSpace(inlineData.Get("data").String())
|
||||
if mimeType != "" && data != "" {
|
||||
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
|
||||
}
|
||||
}
|
||||
addModerationImage(images, part.Get("file_data.file_uri").String())
|
||||
addModerationImage(images, part.Get("fileData.fileUri").String())
|
||||
}
|
||||
|
||||
func addModerationImageData(images *[]string, mimeType string, data string) {
|
||||
mimeType = strings.TrimSpace(mimeType)
|
||||
data = strings.TrimSpace(data)
|
||||
if mimeType == "" || data == "" {
|
||||
return
|
||||
}
|
||||
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
|
||||
}
|
||||
|
||||
func addModerationImage(images *[]string, image string) {
|
||||
image = strings.TrimSpace(image)
|
||||
if image == "" {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") {
|
||||
*images = append(*images, image)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeModerationImages(images []string) []string {
|
||||
out := make([]string, 0, len(images))
|
||||
seen := make(map[string]struct{}, len(images))
|
||||
for _, image := range images {
|
||||
image = strings.TrimSpace(image)
|
||||
if image == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[image]; ok {
|
||||
continue
|
||||
}
|
||||
seen[image] = struct{}{}
|
||||
out = append(out, image)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func limitContentModerationImages(images []string) []string {
|
||||
if len(images) <= maxContentModerationInputImages {
|
||||
return images
|
||||
}
|
||||
idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(images))))
|
||||
if err != nil {
|
||||
return images[:maxContentModerationInputImages]
|
||||
}
|
||||
return []string{images[int(idx.Int64())]}
|
||||
}
|
||||
|
||||
func addModerationText(parts *[]string, text string) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if strings.Contains(text, "<system-reminder>") {
|
||||
return
|
||||
}
|
||||
*parts = append(*parts, text)
|
||||
}
|
||||
|
||||
func normalizeContentModerationText(text string) string {
|
||||
return strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
|
||||
}
|
||||
37
backend/internal/service/content_moderation_redact.go
Normal file
37
backend/internal/service/content_moderation_redact.go
Normal file
@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var contentModerationSecretPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\bhttps?://[^\s"'<>,。;、]+`),
|
||||
regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`),
|
||||
regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`),
|
||||
regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`),
|
||||
regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`),
|
||||
regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`),
|
||||
regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`),
|
||||
regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`),
|
||||
regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`),
|
||||
}
|
||||
|
||||
func redactContentModerationSecrets(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
out := text
|
||||
for idx, pattern := range contentModerationSecretPatterns {
|
||||
switch idx {
|
||||
case 1:
|
||||
out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`)
|
||||
case 2:
|
||||
out = pattern.ReplaceAllString(out, `${1}[已脱敏]`)
|
||||
default:
|
||||
out = pattern.ReplaceAllString(out, `[已脱敏]`)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
1006
backend/internal/service/content_moderation_test.go
Normal file
1006
backend/internal/service/content_moderation_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -108,6 +108,12 @@ const (
|
||||
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
||||
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
||||
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
||||
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
|
||||
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON)
|
||||
SettingKeyLoginAgreementEnabled = "login_agreement_enabled" // 登录前是否要求同意条款
|
||||
SettingKeyLoginAgreementMode = "login_agreement_mode" // 条款确认展示模式:modal / checkbox
|
||||
SettingKeyLoginAgreementUpdatedAt = "login_agreement_updated_at" // 条款更新日期(展示用)
|
||||
SettingKeyLoginAgreementDocuments = "login_agreement_documents" // 条款文档列表(JSON,Markdown 内容)
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@ -174,6 +180,18 @@ const (
|
||||
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
|
||||
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
|
||||
|
||||
// GitHub / Google 邮箱快捷登录设置
|
||||
SettingKeyGitHubOAuthEnabled = "github_oauth_enabled"
|
||||
SettingKeyGitHubOAuthClientID = "github_oauth_client_id"
|
||||
SettingKeyGitHubOAuthClientSecret = "github_oauth_client_secret"
|
||||
SettingKeyGitHubOAuthRedirectURL = "github_oauth_redirect_url"
|
||||
SettingKeyGitHubOAuthFrontendRedirectURL = "github_oauth_frontend_redirect_url"
|
||||
SettingKeyGoogleOAuthEnabled = "google_oauth_enabled"
|
||||
SettingKeyGoogleOAuthClientID = "google_oauth_client_id"
|
||||
SettingKeyGoogleOAuthClientSecret = "google_oauth_client_secret"
|
||||
SettingKeyGoogleOAuthRedirectURL = "google_oauth_redirect_url"
|
||||
SettingKeyGoogleOAuthFrontendRedirectURL = "google_oauth_frontend_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
@ -217,6 +235,16 @@ const (
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
|
||||
SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
|
||||
SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
|
||||
SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
|
||||
SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
|
||||
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
|
||||
|
||||
// 管理员 API Key
|
||||
|
||||
@ -124,6 +124,24 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
||||
require.Contains(t, got, "fast-mode-2026-02-01")
|
||||
}
|
||||
|
||||
func TestFullClaudeCodeMimicryBetas_DoesNotDefaultRedactThinking(t *testing.T) {
|
||||
required := claude.FullClaudeCodeMimicryBetas()
|
||||
|
||||
require.NotContains(t, required, claude.BetaRedactThinking)
|
||||
require.Contains(t, required, claude.BetaClaudeCode)
|
||||
require.Contains(t, required, claude.BetaOAuth)
|
||||
require.Contains(t, required, claude.BetaInterleavedThinking)
|
||||
}
|
||||
|
||||
func TestMergeAnthropicBetaDropping_PreservesIncomingRedactThinking(t *testing.T) {
|
||||
required := claude.FullClaudeCodeMimicryBetas()
|
||||
incoming := claude.BetaRedactThinking
|
||||
|
||||
got := mergeAnthropicBetaDropping(required, incoming, droppedBetaSet())
|
||||
|
||||
require.Contains(t, got, claude.BetaRedactThinking)
|
||||
}
|
||||
|
||||
func TestDroppedBetaSet(t *testing.T) {
|
||||
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
|
||||
base := droppedBetaSet()
|
||||
|
||||
@ -5445,6 +5445,12 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
flusher.Flush()
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
if clientDisconnected && streamInterval > 0 {
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) >= streamInterval {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
|
||||
@ -440,6 +440,21 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont
|
||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isCodexImageGenerationBridgeEnabled(ctx context.Context, account *Account, apiKey *APIKey) bool {
|
||||
if override := account.CodexImageGenerationBridgeOverride(); override != nil {
|
||||
return *override
|
||||
}
|
||||
if s != nil && s.channelService != nil && apiKey != nil && apiKey.GroupID != nil {
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *apiKey.GroupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to resolve codex image generation bridge channel override", "group_id", *apiKey.GroupID, "error", err)
|
||||
} else if override := ch.CodexImageGenerationBridgeOverride(PlatformOpenAI); override != nil {
|
||||
return *override
|
||||
}
|
||||
}
|
||||
return s != nil && s.cfg != nil && s.cfg.Gateway.CodexImageGenerationBridgeEnabled
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||
return false
|
||||
@ -2059,6 +2074,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
if apiKey != nil {
|
||||
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
|
||||
}
|
||||
codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey)
|
||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
@ -2128,7 +2144,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||
}
|
||||
|
||||
if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||
if codexImageGenerationBridgeEnabled && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
|
||||
@ -2139,7 +2155,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||
}
|
||||
if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||
if codexImageGenerationBridgeEnabled && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
|
||||
|
||||
@ -83,12 +83,14 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowImages bool
|
||||
wantInjected bool
|
||||
name string
|
||||
allowImages bool
|
||||
bridgeEnabled bool
|
||||
wantInjected bool
|
||||
}{
|
||||
{name: "disabled group skips injection", allowImages: false, wantInjected: false},
|
||||
{name: "enabled group injects image tool", allowImages: true, wantInjected: true},
|
||||
{name: "disabled group skips injection", allowImages: false, bridgeEnabled: true, wantInjected: false},
|
||||
{name: "enabled group skips injection by default", allowImages: true, bridgeEnabled: false, wantInjected: false},
|
||||
{name: "enabled group injects image tool when bridge enabled", allowImages: true, bridgeEnabled: true, wantInjected: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -101,6 +103,7 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.bridgeEnabled
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
@ -117,6 +120,154 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForward_ExplicitImageToolWorksWithBridgeDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_explicit_image","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
body := []byte(`{"model":"gpt-5.4","input":"draw","stream":false,"tools":[{"type":"image_generation","format":"jpeg"}]}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists())
|
||||
require.Equal(t, "jpeg", gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").output_format`).String())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").format`).Exists())
|
||||
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
|
||||
require.NotContains(t, instructions, "image_generation")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForward_ChannelBridgeOverrideEnablesCodexInjection(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_channel_bridge","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||
groupID := int64(4242)
|
||||
svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, &Channel{
|
||||
ID: 9001,
|
||||
Status: StatusActive,
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
|
||||
},
|
||||
})
|
||||
c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0")
|
||||
account := newOpenAIImageGenerationControlTestAccount()
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists())
|
||||
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
|
||||
require.Contains(t, instructions, "image_generation")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_CodexImageGenerationBridgeOverridePrecedence(t *testing.T) {
|
||||
groupID := int64(4242)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
global bool
|
||||
channel *Channel
|
||||
account *Account
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "global default enables bridge",
|
||||
global: true,
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "channel true overrides disabled global",
|
||||
global: false,
|
||||
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
|
||||
}},
|
||||
account: &Account{Platform: PlatformOpenAI},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "channel false overrides enabled global",
|
||||
global: true,
|
||||
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false},
|
||||
}},
|
||||
account: &Account{Platform: PlatformOpenAI},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "account false overrides channel and global true",
|
||||
global: true,
|
||||
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
|
||||
}},
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{featureKeyCodexImageGenerationBridge: false},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nested account true overrides channel false",
|
||||
global: false,
|
||||
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false},
|
||||
}},
|
||||
account: &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Extra: map[string]any{
|
||||
PlatformOpenAI: map[string]any{"codex_image_generation_bridge_enabled": true},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non openai account extra is ignored",
|
||||
global: false,
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Extra: map[string]any{featureKeyCodexImageGenerationBridge: true},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
|
||||
svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.global
|
||||
if tt.channel != nil {
|
||||
svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, tt.channel)
|
||||
}
|
||||
apiKey := &APIKey{GroupID: &groupID}
|
||||
|
||||
got := svc.isCodexImageGenerationBridgeEnabled(context.Background(), tt.account, apiKey)
|
||||
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -180,6 +331,18 @@ func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder)
|
||||
}
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlChannelService(groupID int64, ch *Channel) *ChannelService {
|
||||
svc := &ChannelService{}
|
||||
cache := newEmptyChannelCache()
|
||||
if ch != nil {
|
||||
cache.channelByGroupID[groupID] = ch
|
||||
cache.byID[ch.ID] = ch
|
||||
}
|
||||
cache.loadedAt = time.Now()
|
||||
svc.cache.Store(cache)
|
||||
return svc
|
||||
}
|
||||
|
||||
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
@ -90,6 +90,69 @@ type OpenAIImagesRequest struct {
|
||||
bodyHash string
|
||||
}
|
||||
|
||||
func (r *OpenAIImagesRequest) ModerationBody() []byte {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{}
|
||||
if prompt := strings.TrimSpace(r.Prompt); prompt != "" {
|
||||
payload["prompt"] = prompt
|
||||
}
|
||||
images := r.moderationImages()
|
||||
if len(images) > 0 {
|
||||
payload["images"] = images
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (r *OpenAIImagesRequest) moderationImages() []map[string]string {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1)
|
||||
for _, imageURL := range r.InputImageURLs {
|
||||
imageURL = strings.TrimSpace(imageURL)
|
||||
if imageURL != "" {
|
||||
images = append(images, map[string]string{"image_url": imageURL})
|
||||
}
|
||||
}
|
||||
for _, upload := range r.Uploads {
|
||||
if dataURL := upload.ModerationDataURL(); dataURL != "" {
|
||||
images = append(images, map[string]string{"image_url": dataURL})
|
||||
}
|
||||
}
|
||||
if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" {
|
||||
images = append(images, map[string]string{"image_url": maskURL})
|
||||
}
|
||||
if r.MaskUpload != nil {
|
||||
if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" {
|
||||
images = append(images, map[string]string{"image_url": dataURL})
|
||||
}
|
||||
}
|
||||
return images
|
||||
}
|
||||
|
||||
func (u OpenAIImagesUpload) ModerationDataURL() string {
|
||||
if len(u.Data) == 0 {
|
||||
return ""
|
||||
}
|
||||
contentType := strings.TrimSpace(u.ContentType)
|
||||
if contentType == "" {
|
||||
contentType = http.DetectContentType(u.Data)
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(contentType), "image/") {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data))
|
||||
}
|
||||
|
||||
func (r *OpenAIImagesRequest) IsEdits() bool {
|
||||
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
|
||||
}
|
||||
|
||||
@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesEditsEndpoint,
|
||||
Prompt: "replace background",
|
||||
InputImageURLs: []string{"https://example.com/source.png"},
|
||||
MaskImageURL: "https://example.com/mask.png",
|
||||
}
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
|
||||
|
||||
require.Equal(t, "replace background", input.Text)
|
||||
require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images)
|
||||
}
|
||||
|
||||
func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesEditsEndpoint,
|
||||
Prompt: "replace background",
|
||||
Uploads: []OpenAIImagesUpload{{
|
||||
FieldName: "image",
|
||||
FileName: "source.png",
|
||||
ContentType: "image/png",
|
||||
Data: []byte("fake-image-bytes"),
|
||||
}},
|
||||
MaskUpload: &OpenAIImagesUpload{
|
||||
FieldName: "mask",
|
||||
FileName: "mask.png",
|
||||
ContentType: "image/png",
|
||||
Data: []byte("fake-mask-bytes"),
|
||||
},
|
||||
}
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
|
||||
|
||||
require.Equal(t, "replace background", input.Text)
|
||||
require.Equal(t, []string{
|
||||
"data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==",
|
||||
"data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz",
|
||||
}, input.Images)
|
||||
|
||||
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
|
||||
require.Equal(t, "replace background", log.InputExcerpt)
|
||||
require.NotContains(t, log.InputExcerpt, "ZmFrZS")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -52,3 +52,47 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess
|
||||
require.Equal(t, "client-id-1", info.ClientID)
|
||||
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
|
||||
}
|
||||
|
||||
func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) {
|
||||
refresher := NewOpenAITokenRefresher(nil, nil)
|
||||
expiresAt := time.Now().Add(time.Minute).UTC().Format(time.RFC3339)
|
||||
|
||||
withoutRT := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
require.False(t, refresher.NeedsRefresh(withoutRT, 5*time.Minute))
|
||||
|
||||
withRT := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
require.True(t, refresher.NeedsRefresh(withRT, 5*time.Minute))
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NoRefreshTokenExpiredAccessTokenReturnsError(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "expired-access-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, token)
|
||||
require.Contains(t, err.Error(), "refresh_token is missing")
|
||||
}
|
||||
|
||||
@ -152,6 +152,12 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||
if expiresAt != nil && !time.Now().Before(*expiresAt) {
|
||||
return "", errors.New("openai access_token expired and refresh_token is missing")
|
||||
}
|
||||
needsRefresh = false
|
||||
}
|
||||
refreshFailed := false
|
||||
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
|
||||
@ -424,8 +424,9 @@ func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@ -650,8 +651,9 @@ func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@ -819,8 +821,9 @@ func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@ -848,8 +851,9 @@ func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
@ -875,8 +879,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T)
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
@ -911,8 +916,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
"access_token": "fallback-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct {
|
||||
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||
InitialRequestModel string
|
||||
BeforeTurn func(turn int) error
|
||||
BeforeRequest func(turn int, payload []byte, originalModel string) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
}
|
||||
|
||||
@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
return true
|
||||
}
|
||||
for {
|
||||
if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil {
|
||||
if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
|
||||
if err := hooks.BeforeTurn(turn); err != nil {
|
||||
return err
|
||||
|
||||
@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
if msgType != coderws.MessageText {
|
||||
return payload, nil, nil
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil {
|
||||
turnNo := int(completedTurns.Load()) + 1
|
||||
if turnNo < 2 {
|
||||
turnNo = 2
|
||||
}
|
||||
requestModel := usageMeta.requestModelForFrame(payload)
|
||||
if requestModel == "" {
|
||||
requestModel = capturedSessionModel
|
||||
}
|
||||
if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil {
|
||||
return payload, nil, err
|
||||
}
|
||||
}
|
||||
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
||||
// session.update 修改 session-level model(Realtime /
|
||||
// Responses WS 协议允许),如果不刷新就会出现
|
||||
|
||||
@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
||||
case redeemActionRedeem:
|
||||
// Code exists but unused — skip creation, proceed to redeem
|
||||
}
|
||||
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
|
||||
if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil {
|
||||
return fmt.Errorf("redeem balance: %w", err)
|
||||
}
|
||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||
|
||||
@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
|
||||
nil,
|
||||
client,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentOrderLifecycleQueryProvider{
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@ -28,6 +29,15 @@ const (
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
)
|
||||
|
||||
type ctxKeySkipRedeemAffiliate struct{}
|
||||
|
||||
// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level
|
||||
// affiliate rebate. Used by payment fulfillment which handles rebate separately
|
||||
// via applyAffiliateRebateForOrder (with audit-log deduplication).
|
||||
func ContextSkipRedeemAffiliate(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true)
|
||||
}
|
||||
|
||||
// RedeemCache defines cache operations for redeem service
|
||||
type RedeemCache interface {
|
||||
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
@ -80,6 +90,7 @@ type RedeemService struct {
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
affiliateService *AffiliateService
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
@ -91,6 +102,7 @@ func NewRedeemService(
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
affiliateService *AffiliateService,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
@ -100,6 +112,7 @@ func NewRedeemService(
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
affiliateService: affiliateService,
|
||||
}
|
||||
}
|
||||
|
||||
@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
// 事务提交成功后失效缓存
|
||||
s.invalidateRedeemCaches(ctx, userID, redeemCode)
|
||||
|
||||
// 余额类正数兑换码触发邀请返利(best-effort,失败不影响兑换结果)
|
||||
if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 {
|
||||
s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value)
|
||||
}
|
||||
|
||||
// 重新获取更新后的兑换码
|
||||
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
||||
if err != nil {
|
||||
@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) {
|
||||
if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil {
|
||||
return
|
||||
}
|
||||
if s.affiliateService == nil {
|
||||
return
|
||||
}
|
||||
if !s.affiliateService.IsEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err)
|
||||
return
|
||||
}
|
||||
if rebate > 0 {
|
||||
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取兑换码
|
||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||
|
||||
@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@ -129,6 +130,8 @@ type AuthSourceDefaultSettings struct {
|
||||
LinuxDo ProviderDefaultGrantSettings
|
||||
OIDC ProviderDefaultGrantSettings
|
||||
WeChat ProviderDefaultGrantSettings
|
||||
GitHub ProviderDefaultGrantSettings
|
||||
Google ProviderDefaultGrantSettings
|
||||
ForceEmailOnThirdPartySignup bool
|
||||
}
|
||||
|
||||
@ -169,6 +172,20 @@ var (
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
}
|
||||
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultGitHubBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||
}
|
||||
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultGoogleBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
@ -177,8 +194,151 @@ const (
|
||||
defaultWeChatConnectMode = "open"
|
||||
defaultWeChatConnectScopes = "snsapi_login"
|
||||
defaultWeChatConnectFrontend = "/auth/wechat/callback"
|
||||
defaultGitHubOAuthAuthorize = "https://github.com/login/oauth/authorize"
|
||||
defaultGitHubOAuthToken = "https://github.com/login/oauth/access_token"
|
||||
defaultGitHubOAuthUserInfo = "https://api.github.com/user"
|
||||
defaultGitHubOAuthEmails = "https://api.github.com/user/emails"
|
||||
defaultGitHubOAuthScopes = "read:user user:email"
|
||||
defaultGitHubOAuthFrontend = "/auth/oauth/callback"
|
||||
defaultGoogleOAuthAuthorize = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
defaultGoogleOAuthToken = "https://oauth2.googleapis.com/token"
|
||||
defaultGoogleOAuthUserInfo = "https://openidconnect.googleapis.com/v1/userinfo"
|
||||
defaultGoogleOAuthScopes = "openid email profile"
|
||||
defaultGoogleOAuthFrontend = "/auth/oauth/callback"
|
||||
defaultLoginAgreementMode = "modal"
|
||||
defaultLoginAgreementDate = "2026-03-31"
|
||||
)
|
||||
|
||||
func normalizeLoginAgreementMode(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "checkbox":
|
||||
return "checkbox"
|
||||
default:
|
||||
return defaultLoginAgreementMode
|
||||
}
|
||||
}
|
||||
|
||||
func defaultLoginAgreementDocuments() []LoginAgreementDocument {
|
||||
return []LoginAgreementDocument{
|
||||
{
|
||||
ID: "terms",
|
||||
Title: "服务条款",
|
||||
ContentMD: "",
|
||||
},
|
||||
{
|
||||
ID: "usage-policy",
|
||||
Title: "使用政策",
|
||||
ContentMD: "",
|
||||
},
|
||||
{
|
||||
ID: "supported-regions",
|
||||
Title: "支持的国家和地区",
|
||||
ContentMD: "",
|
||||
},
|
||||
{
|
||||
ID: "service-specific-terms",
|
||||
Title: "服务特定条款",
|
||||
ContentMD: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeLoginAgreementDocumentID(raw string) string {
|
||||
raw = strings.ToLower(strings.TrimSpace(raw))
|
||||
var b strings.Builder
|
||||
lastSeparator := false
|
||||
for _, r := range raw {
|
||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') {
|
||||
_, _ = b.WriteRune(r)
|
||||
lastSeparator = false
|
||||
continue
|
||||
}
|
||||
if r == '-' || r == '_' || r == ' ' || r == '.' || r == '/' {
|
||||
if !lastSeparator && b.Len() > 0 {
|
||||
if r == '_' {
|
||||
_, _ = b.WriteRune('_')
|
||||
} else {
|
||||
_, _ = b.WriteRune('-')
|
||||
}
|
||||
lastSeparator = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Trim(b.String(), "-_")
|
||||
}
|
||||
|
||||
func normalizeLoginAgreementDocuments(docs []LoginAgreementDocument) []LoginAgreementDocument {
|
||||
normalized := make([]LoginAgreementDocument, 0, len(docs))
|
||||
seen := make(map[string]int, len(docs))
|
||||
for i, doc := range docs {
|
||||
title := strings.TrimSpace(doc.Title)
|
||||
content := strings.TrimSpace(doc.ContentMD)
|
||||
if title == "" && content == "" {
|
||||
continue
|
||||
}
|
||||
id := normalizeLoginAgreementDocumentID(doc.ID)
|
||||
if id == "" {
|
||||
sum := sha256.Sum256([]byte(fmt.Sprintf("%d:%s:%s", i, title, content)))
|
||||
id = hex.EncodeToString(sum[:])[:12]
|
||||
}
|
||||
baseID := id
|
||||
for suffix := 2; seen[id] > 0; suffix++ {
|
||||
id = fmt.Sprintf("%s-%d", baseID, suffix)
|
||||
}
|
||||
seen[id]++
|
||||
normalized = append(normalized, LoginAgreementDocument{
|
||||
ID: id,
|
||||
Title: title,
|
||||
ContentMD: content,
|
||||
})
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func parseLoginAgreementDocuments(raw string) []LoginAgreementDocument {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return defaultLoginAgreementDocuments()
|
||||
}
|
||||
var docs []LoginAgreementDocument
|
||||
if err := json.Unmarshal([]byte(raw), &docs); err != nil {
|
||||
return defaultLoginAgreementDocuments()
|
||||
}
|
||||
docs = normalizeLoginAgreementDocuments(docs)
|
||||
if len(docs) == 0 {
|
||||
return defaultLoginAgreementDocuments()
|
||||
}
|
||||
return docs
|
||||
}
|
||||
|
||||
func marshalLoginAgreementDocuments(docs []LoginAgreementDocument) (string, error) {
|
||||
normalized := normalizeLoginAgreementDocuments(docs)
|
||||
if len(normalized) == 0 {
|
||||
normalized = defaultLoginAgreementDocuments()
|
||||
}
|
||||
b, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal login agreement documents: %w", err)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func buildLoginAgreementRevision(updatedAt string, docs []LoginAgreementDocument) string {
|
||||
normalized := normalizeLoginAgreementDocuments(docs)
|
||||
payload, err := json.Marshal(struct {
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Documents []LoginAgreementDocument `json:"documents"`
|
||||
}{
|
||||
UpdatedAt: strings.TrimSpace(updatedAt),
|
||||
Documents: normalized,
|
||||
})
|
||||
if err != nil {
|
||||
payload = []byte(strings.TrimSpace(updatedAt))
|
||||
}
|
||||
sum := sha256.Sum256(payload)
|
||||
return hex.EncodeToString(sum[:])[:16]
|
||||
}
|
||||
|
||||
func normalizeWeChatConnectModeSetting(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "mp":
|
||||
@ -411,6 +571,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyPasswordResetEnabled,
|
||||
SettingKeyInvitationCodeEnabled,
|
||||
SettingKeyTotpEnabled,
|
||||
SettingKeyLoginAgreementEnabled,
|
||||
SettingKeyLoginAgreementMode,
|
||||
SettingKeyLoginAgreementUpdatedAt,
|
||||
SettingKeyLoginAgreementDocuments,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeySiteName,
|
||||
@ -448,6 +612,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingPaymentEnabled,
|
||||
SettingKeyOIDCConnectEnabled,
|
||||
SettingKeyOIDCConnectProviderName,
|
||||
SettingKeyGitHubOAuthEnabled,
|
||||
SettingKeyGitHubOAuthClientID,
|
||||
SettingKeyGitHubOAuthClientSecret,
|
||||
SettingKeyGoogleOAuthEnabled,
|
||||
SettingKeyGoogleOAuthClientID,
|
||||
SettingKeyGoogleOAuthClientSecret,
|
||||
SettingKeyBalanceLowNotifyEnabled,
|
||||
SettingKeyBalanceLowNotifyThreshold,
|
||||
SettingKeyBalanceLowNotifyRechargeURL,
|
||||
@ -456,6 +626,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyChannelMonitorDefaultIntervalSeconds,
|
||||
SettingKeyAvailableChannelsEnabled,
|
||||
SettingKeyAffiliateEnabled,
|
||||
SettingKeyRiskControlEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@ -482,6 +653,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
if oidcProviderName == "" {
|
||||
oidcProviderName = "OIDC"
|
||||
}
|
||||
gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github")
|
||||
googleEnabled := s.emailOAuthPublicEnabled(settings, "google")
|
||||
weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
|
||||
|
||||
// Password reset requires email verification to be enabled
|
||||
@ -494,6 +667,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
settings[SettingKeyTableDefaultPageSize],
|
||||
settings[SettingKeyTablePageSizeOptions],
|
||||
)
|
||||
loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
|
||||
loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
|
||||
if loginAgreementUpdatedAt == "" {
|
||||
loginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||
}
|
||||
|
||||
var balanceLowNotifyThreshold float64
|
||||
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
|
||||
@ -509,6 +687,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
PasswordResetEnabled: passwordResetEnabled,
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true" && len(loginAgreementDocuments) > 0,
|
||||
LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementRevision: buildLoginAgreementRevision(loginAgreementUpdatedAt, loginAgreementDocuments),
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
@ -534,6 +717,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
|
||||
OIDCOAuthEnabled: oidcEnabled,
|
||||
OIDCOAuthProviderName: oidcProviderName,
|
||||
GitHubOAuthEnabled: gitHubEnabled,
|
||||
GoogleOAuthEnabled: googleEnabled,
|
||||
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
|
||||
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
|
||||
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
|
||||
@ -545,6 +730,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
|
||||
|
||||
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
|
||||
|
||||
RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -647,43 +834,50 @@ func (s *SettingService) SetVersion(version string) {
|
||||
// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
|
||||
// drift automatically (see setting_service_injection_test.go).
|
||||
type PublicSettingsInjectionPayload struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||
LoginAgreementRevision string `json:"login_agreement_revision"`
|
||||
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
|
||||
// Feature flags — MUST match the opt-in/opt-out registry in
|
||||
// frontend/src/utils/featureFlags.ts. Missing a field here is the bug
|
||||
@ -692,6 +886,7 @@ type PublicSettingsInjectionPayload struct {
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
|
||||
@ -710,6 +905,11 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
LoginAgreementEnabled: settings.LoginAgreementEnabled,
|
||||
LoginAgreementMode: settings.LoginAgreementMode,
|
||||
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
|
||||
LoginAgreementRevision: settings.LoginAgreementRevision,
|
||||
LoginAgreementDocuments: settings.LoginAgreementDocuments,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
@ -733,6 +933,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: s.version,
|
||||
@ -745,6 +947,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
RiskControlEnabled: settings.RiskControlEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -806,6 +1009,98 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string
|
||||
return openReady || mpReady, openReady, mpReady, mobileReady
|
||||
}
|
||||
|
||||
func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "github":
|
||||
cfg := config.EmailOAuthProviderConfig{
|
||||
AuthorizeURL: defaultGitHubOAuthAuthorize,
|
||||
TokenURL: defaultGitHubOAuthToken,
|
||||
UserInfoURL: defaultGitHubOAuthUserInfo,
|
||||
EmailsURL: defaultGitHubOAuthEmails,
|
||||
Scopes: defaultGitHubOAuthScopes,
|
||||
FrontendRedirectURL: defaultGitHubOAuthFrontend,
|
||||
}
|
||||
if s != nil && s.cfg != nil {
|
||||
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GitHubOAuth)
|
||||
}
|
||||
return cfg
|
||||
case "google":
|
||||
cfg := config.EmailOAuthProviderConfig{
|
||||
AuthorizeURL: defaultGoogleOAuthAuthorize,
|
||||
TokenURL: defaultGoogleOAuthToken,
|
||||
UserInfoURL: defaultGoogleOAuthUserInfo,
|
||||
Scopes: defaultGoogleOAuthScopes,
|
||||
FrontendRedirectURL: defaultGoogleOAuthFrontend,
|
||||
}
|
||||
if s != nil && s.cfg != nil {
|
||||
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GoogleOAuth)
|
||||
}
|
||||
return cfg
|
||||
default:
|
||||
return config.EmailOAuthProviderConfig{}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) config.EmailOAuthProviderConfig {
|
||||
base.Enabled = override.Enabled
|
||||
if strings.TrimSpace(override.ClientID) != "" {
|
||||
base.ClientID = strings.TrimSpace(override.ClientID)
|
||||
}
|
||||
if strings.TrimSpace(override.ClientSecret) != "" {
|
||||
base.ClientSecret = strings.TrimSpace(override.ClientSecret)
|
||||
}
|
||||
if strings.TrimSpace(override.AuthorizeURL) != "" {
|
||||
base.AuthorizeURL = strings.TrimSpace(override.AuthorizeURL)
|
||||
}
|
||||
if strings.TrimSpace(override.TokenURL) != "" {
|
||||
base.TokenURL = strings.TrimSpace(override.TokenURL)
|
||||
}
|
||||
if strings.TrimSpace(override.UserInfoURL) != "" {
|
||||
base.UserInfoURL = strings.TrimSpace(override.UserInfoURL)
|
||||
}
|
||||
if strings.TrimSpace(override.EmailsURL) != "" {
|
||||
base.EmailsURL = strings.TrimSpace(override.EmailsURL)
|
||||
}
|
||||
if strings.TrimSpace(override.Scopes) != "" {
|
||||
base.Scopes = strings.TrimSpace(override.Scopes)
|
||||
}
|
||||
if strings.TrimSpace(override.RedirectURL) != "" {
|
||||
base.RedirectURL = strings.TrimSpace(override.RedirectURL)
|
||||
}
|
||||
if strings.TrimSpace(override.FrontendRedirectURL) != "" {
|
||||
base.FrontendRedirectURL = strings.TrimSpace(override.FrontendRedirectURL)
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool {
|
||||
cfg := s.effectiveEmailOAuthConfig(settings, provider)
|
||||
return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != ""
|
||||
}
|
||||
|
||||
func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig {
|
||||
cfg := s.emailOAuthBaseConfig(provider)
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "github":
|
||||
if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok {
|
||||
cfg.Enabled = raw == "true"
|
||||
}
|
||||
cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID)
|
||||
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret)
|
||||
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL)
|
||||
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend)
|
||||
case "google":
|
||||
if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok {
|
||||
cfg.Enabled = raw == "true"
|
||||
}
|
||||
cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID)
|
||||
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret)
|
||||
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL)
|
||||
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
|
||||
// array string, returning only items with visibility != "admin".
|
||||
func filterUserVisibleMenuItems(raw string) json.RawMessage {
|
||||
@ -1052,6 +1347,16 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
if settings.WeChatConnectFrontendRedirectURL == "" {
|
||||
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
|
||||
}
|
||||
settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL)
|
||||
settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL)
|
||||
if settings.GitHubOAuthFrontendRedirectURL == "" {
|
||||
settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend
|
||||
}
|
||||
settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL)
|
||||
settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL)
|
||||
if settings.GoogleOAuthFrontendRedirectURL == "" {
|
||||
settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend
|
||||
}
|
||||
|
||||
updates := make(map[string]string)
|
||||
|
||||
@ -1068,6 +1373,19 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyFrontendURL] = settings.FrontendURL
|
||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||
settings.LoginAgreementMode = normalizeLoginAgreementMode(settings.LoginAgreementMode)
|
||||
settings.LoginAgreementUpdatedAt = strings.TrimSpace(settings.LoginAgreementUpdatedAt)
|
||||
if settings.LoginAgreementUpdatedAt == "" {
|
||||
settings.LoginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||
}
|
||||
loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(settings.LoginAgreementDocuments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updates[SettingKeyLoginAgreementEnabled] = strconv.FormatBool(settings.LoginAgreementEnabled)
|
||||
updates[SettingKeyLoginAgreementMode] = settings.LoginAgreementMode
|
||||
updates[SettingKeyLoginAgreementUpdatedAt] = settings.LoginAgreementUpdatedAt
|
||||
updates[SettingKeyLoginAgreementDocuments] = loginAgreementDocumentsJSON
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
@ -1121,6 +1439,22 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
|
||||
}
|
||||
|
||||
// GitHub / Google 邮箱快捷登录
|
||||
updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled)
|
||||
updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID)
|
||||
updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL
|
||||
updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL
|
||||
if settings.GitHubOAuthClientSecret != "" {
|
||||
updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret)
|
||||
}
|
||||
updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled)
|
||||
updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID)
|
||||
updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL
|
||||
updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL
|
||||
if settings.GoogleOAuthClientSecret != "" {
|
||||
updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret)
|
||||
}
|
||||
|
||||
// WeChat Connect OAuth 登录
|
||||
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
|
||||
updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
|
||||
@ -1232,6 +1566,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
// Affiliate (邀请返利) feature switch
|
||||
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
|
||||
|
||||
// 风控中心功能开关
|
||||
updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled)
|
||||
|
||||
// Claude Code version check
|
||||
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
|
||||
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
|
||||
@ -1273,17 +1610,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
|
||||
settings.LinuxDo.Subscriptions,
|
||||
settings.OIDC.Subscriptions,
|
||||
settings.WeChat.Subscriptions,
|
||||
settings.GitHub.Subscriptions,
|
||||
settings.Google.Subscriptions,
|
||||
} {
|
||||
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
updates := make(map[string]string, 21)
|
||||
updates := make(map[string]string, 31)
|
||||
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
||||
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
||||
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
||||
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
||||
writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
|
||||
writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
|
||||
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
||||
return updates, nil
|
||||
}
|
||||
@ -1362,6 +1703,61 @@ func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider != "github" && provider != "google" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_PROVIDER_NOT_FOUND", "oauth provider not found")
|
||||
}
|
||||
keys := []string{
|
||||
SettingKeyGitHubOAuthEnabled,
|
||||
SettingKeyGitHubOAuthClientID,
|
||||
SettingKeyGitHubOAuthClientSecret,
|
||||
SettingKeyGitHubOAuthRedirectURL,
|
||||
SettingKeyGitHubOAuthFrontendRedirectURL,
|
||||
SettingKeyGoogleOAuthEnabled,
|
||||
SettingKeyGoogleOAuthClientID,
|
||||
SettingKeyGoogleOAuthClientSecret,
|
||||
SettingKeyGoogleOAuthRedirectURL,
|
||||
SettingKeyGoogleOAuthFrontendRedirectURL,
|
||||
}
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, fmt.Errorf("get email oauth settings: %w", err)
|
||||
}
|
||||
cfg := s.effectiveEmailOAuthConfig(settings, provider)
|
||||
if !cfg.Enabled {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientID) == "" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
for label, rawURL := range map[string]string{
|
||||
"authorize": cfg.AuthorizeURL,
|
||||
"token": cfg.TokenURL,
|
||||
"userinfo": cfg.UserInfoURL,
|
||||
"redirect": cfg.RedirectURL,
|
||||
} {
|
||||
if strings.TrimSpace(rawURL) == "" {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url not configured")
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(rawURL); err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url invalid")
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(cfg.EmailsURL) != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(cfg.EmailsURL); err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth emails url invalid")
|
||||
}
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
|
||||
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
@ -1534,6 +1930,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting.
|
||||
func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
|
||||
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
|
||||
@ -1711,6 +2116,16 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultGitHubBalance,
|
||||
SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||
SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultGoogleBalance,
|
||||
SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||
SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||
SettingKeyForceEmailOnThirdPartySignup,
|
||||
}
|
||||
|
||||
@ -1724,6 +2139,8 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
||||
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
|
||||
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
|
||||
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
|
||||
GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys),
|
||||
Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys),
|
||||
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
|
||||
}, nil
|
||||
}
|
||||
@ -1793,6 +2210,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
|
||||
}
|
||||
}
|
||||
loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(defaultLoginAgreementDocuments())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化默认设置
|
||||
defaults := map[string]string{
|
||||
@ -1800,6 +2221,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeyLoginAgreementEnabled: "false",
|
||||
SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
|
||||
SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
|
||||
SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
@ -1824,6 +2249,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyWeChatConnectScopes: "snsapi_login",
|
||||
SettingKeyWeChatConnectRedirectURL: "",
|
||||
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
|
||||
SettingKeyGitHubOAuthEnabled: "false",
|
||||
SettingKeyGitHubOAuthClientID: "",
|
||||
SettingKeyGitHubOAuthClientSecret: "",
|
||||
SettingKeyGitHubOAuthRedirectURL: "",
|
||||
SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
|
||||
SettingKeyGoogleOAuthEnabled: "false",
|
||||
SettingKeyGoogleOAuthClientID: "",
|
||||
SettingKeyGoogleOAuthClientSecret: "",
|
||||
SettingKeyGoogleOAuthRedirectURL: "",
|
||||
SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
|
||||
SettingKeyOIDCConnectEnabled: "false",
|
||||
SettingKeyOIDCConnectProviderName: "OIDC",
|
||||
SettingKeyOIDCConnectClientID: "",
|
||||
@ -1874,6 +2309,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultGitHubBalance: "0",
|
||||
SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultGoogleBalance: "0",
|
||||
SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
|
||||
SettingKeyForceEmailOnThirdPartySignup: "false",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
@ -1903,6 +2348,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// Affiliate (邀请返利) feature (default disabled; opt-in)
|
||||
SettingKeyAffiliateEnabled: "false",
|
||||
|
||||
// 风控中心功能(默认关闭,显式启用)
|
||||
SettingKeyRiskControlEnabled: "false",
|
||||
|
||||
// Claude Code version check (default: empty = disabled)
|
||||
SettingKeyMinClaudeCodeVersion: "",
|
||||
SettingKeyMaxClaudeCodeVersion: "",
|
||||
@ -1923,6 +2371,11 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// parseSettings 解析设置到结构体
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||
loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
|
||||
loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
|
||||
if loginAgreementUpdatedAt == "" {
|
||||
loginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||
}
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
@ -1932,6 +2385,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
FrontendURL: settings[SettingKeyFrontendURL],
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true",
|
||||
LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
@ -2173,6 +2630,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
}
|
||||
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
|
||||
|
||||
gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github")
|
||||
result.GitHubOAuthEnabled = gitHubEffective.Enabled
|
||||
result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID)
|
||||
result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret)
|
||||
result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != ""
|
||||
result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL)
|
||||
result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL)
|
||||
|
||||
googleEffective := s.effectiveEmailOAuthConfig(settings, "google")
|
||||
result.GoogleOAuthEnabled = googleEffective.Enabled
|
||||
result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID)
|
||||
result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret)
|
||||
result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != ""
|
||||
result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL)
|
||||
result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL)
|
||||
|
||||
// WeChat Connect 设置:
|
||||
// - 优先读取 DB 系统设置
|
||||
// - 缺失时回退到 config/env,保持升级兼容
|
||||
@ -2242,6 +2715,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
// Affiliate (邀请返利) feature (default: disabled; strict true)
|
||||
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
|
||||
|
||||
// 风控中心功能(默认关闭,严格 true 才启用)
|
||||
result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true"
|
||||
|
||||
// Claude Code version check
|
||||
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
|
||||
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
|
||||
|
||||
@ -20,6 +20,10 @@ type SystemSettings struct {
|
||||
FrontendURL string
|
||||
InvitationCodeEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
LoginAgreementEnabled bool
|
||||
LoginAgreementMode string
|
||||
LoginAgreementUpdatedAt string
|
||||
LoginAgreementDocuments []LoginAgreementDocument
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
@ -89,6 +93,20 @@ type SystemSettings struct {
|
||||
OIDCConnectUserInfoIDPath string
|
||||
OIDCConnectUserInfoUsernamePath string
|
||||
|
||||
// GitHub / Google 邮箱快捷登录
|
||||
GitHubOAuthEnabled bool
|
||||
GitHubOAuthClientID string
|
||||
GitHubOAuthClientSecret string
|
||||
GitHubOAuthClientSecretConfigured bool
|
||||
GitHubOAuthRedirectURL string
|
||||
GitHubOAuthFrontendRedirectURL string
|
||||
GoogleOAuthEnabled bool
|
||||
GoogleOAuthClientID string
|
||||
GoogleOAuthClientSecret string
|
||||
GoogleOAuthClientSecretConfigured bool
|
||||
GoogleOAuthRedirectURL string
|
||||
GoogleOAuthFrontendRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
@ -106,6 +124,7 @@ type SystemSettings struct {
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
RiskControlEnabled bool
|
||||
AffiliateEnabled bool
|
||||
AffiliateRebateRate float64
|
||||
AffiliateRebateFreezeHours int
|
||||
@ -190,6 +209,11 @@ type PublicSettings struct {
|
||||
PasswordResetEnabled bool
|
||||
InvitationCodeEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
LoginAgreementEnabled bool
|
||||
LoginAgreementMode string
|
||||
LoginAgreementUpdatedAt string
|
||||
LoginAgreementRevision string
|
||||
LoginAgreementDocuments []LoginAgreementDocument
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
@ -217,6 +241,8 @@ type PublicSettings struct {
|
||||
PaymentEnabled bool
|
||||
OIDCOAuthEnabled bool
|
||||
OIDCOAuthProviderName string
|
||||
GitHubOAuthEnabled bool
|
||||
GoogleOAuthEnabled bool
|
||||
Version string
|
||||
|
||||
BalanceLowNotifyEnabled bool
|
||||
@ -233,6 +259,15 @@ type PublicSettings struct {
|
||||
|
||||
// Affiliate (邀请返利) feature toggle
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
// 风控中心功能开关
|
||||
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||
}
|
||||
|
||||
type LoginAgreementDocument struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
ContentMD string `json:"content_md"`
|
||||
}
|
||||
|
||||
type WeChatConnectOAuthConfig struct {
|
||||
|
||||
@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -95,6 +96,9 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
if strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return account.IsRateLimited()
|
||||
|
||||
@ -97,6 +97,8 @@ type UserRepository interface {
|
||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||
BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error)
|
||||
BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error)
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
|
||||
|
||||
@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
|
||||
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||
out := make([]UserAuthIdentityRecord, len(m.identities))
|
||||
|
||||
@ -515,6 +515,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewGroupCapacityService,
|
||||
NewChannelService,
|
||||
NewModelPricingResolver,
|
||||
NewContentModerationService,
|
||||
NewAffiliateService,
|
||||
ProvidePaymentConfigService,
|
||||
NewPaymentService,
|
||||
|
||||
27
backend/migrations/135_allow_email_oauth_provider_types.sql
Normal file
27
backend/migrations/135_allow_email_oauth_provider_types.sql
Normal file
@ -0,0 +1,27 @@
|
||||
ALTER TABLE users
|
||||
DROP CONSTRAINT IF EXISTS users_signup_source_check;
|
||||
|
||||
ALTER TABLE users
|
||||
ADD CONSTRAINT users_signup_source_check
|
||||
CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||
|
||||
ALTER TABLE auth_identities
|
||||
DROP CONSTRAINT IF EXISTS auth_identities_provider_type_check;
|
||||
|
||||
ALTER TABLE auth_identities
|
||||
ADD CONSTRAINT auth_identities_provider_type_check
|
||||
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||
|
||||
ALTER TABLE auth_identity_channels
|
||||
DROP CONSTRAINT IF EXISTS auth_identity_channels_provider_type_check;
|
||||
|
||||
ALTER TABLE auth_identity_channels
|
||||
ADD CONSTRAINT auth_identity_channels_provider_type_check
|
||||
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||
|
||||
ALTER TABLE pending_auth_sessions
|
||||
DROP CONSTRAINT IF EXISTS pending_auth_sessions_provider_type_check;
|
||||
|
||||
ALTER TABLE pending_auth_sessions
|
||||
ADD CONSTRAINT pending_auth_sessions_provider_type_check
|
||||
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||
45
backend/migrations/135_content_moderation.sql
Normal file
45
backend/migrations/135_content_moderation.sql
Normal file
@ -0,0 +1,45 @@
|
||||
-- 风控中心内容审计配置与记录
|
||||
|
||||
INSERT INTO settings (key, value, updated_at)
|
||||
VALUES ('risk_control_enabled', 'false', NOW())
|
||||
ON CONFLICT (key) DO NOTHING;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS content_moderation_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
request_id VARCHAR(128) NOT NULL DEFAULT '',
|
||||
user_id BIGINT REFERENCES users(id) ON DELETE SET NULL,
|
||||
user_email VARCHAR(255) NOT NULL DEFAULT '',
|
||||
api_key_id BIGINT REFERENCES api_keys(id) ON DELETE SET NULL,
|
||||
api_key_name VARCHAR(100) NOT NULL DEFAULT '',
|
||||
group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL,
|
||||
group_name VARCHAR(255) NOT NULL DEFAULT '',
|
||||
endpoint VARCHAR(128) NOT NULL DEFAULT '',
|
||||
provider VARCHAR(64) NOT NULL DEFAULT '',
|
||||
model VARCHAR(255) NOT NULL DEFAULT '',
|
||||
mode VARCHAR(32) NOT NULL DEFAULT '',
|
||||
action VARCHAR(32) NOT NULL DEFAULT '',
|
||||
flagged BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
highest_category VARCHAR(64) NOT NULL DEFAULT '',
|
||||
highest_score DECIMAL(8, 6) NOT NULL DEFAULT 0,
|
||||
category_scores JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
threshold_snapshot JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
input_excerpt TEXT NOT NULL DEFAULT '',
|
||||
upstream_latency_ms INT,
|
||||
error TEXT NOT NULL DEFAULT '',
|
||||
violation_count INT NOT NULL DEFAULT 0,
|
||||
auto_banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
email_sent BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
queue_delay_ms INT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS violation_count INT NOT NULL DEFAULT 0;
|
||||
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS auto_banned BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS email_sent BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS queue_delay_ms INT;
|
||||
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_created_at ON content_moderation_logs(created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_group_created_at ON content_moderation_logs(group_id, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_flagged_created_at ON content_moderation_logs(flagged, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_user_created_at ON content_moderation_logs(user_id, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_api_key_created_at ON content_moderation_logs(api_key_id, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_endpoint_created_at ON content_moderation_logs(endpoint, created_at DESC);
|
||||
@ -142,3 +142,16 @@ func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T)
|
||||
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
|
||||
require.NotContains(t, sql, "detail::jsonb")
|
||||
}
|
||||
|
||||
func TestMigration135AllowsGitHubAndGoogleAuthProviders(t *testing.T) {
|
||||
content, err := FS.ReadFile("135_allow_email_oauth_provider_types.sql")
|
||||
require.NoError(t, err)
|
||||
|
||||
sql := string(content)
|
||||
require.Contains(t, sql, "users_signup_source_check")
|
||||
require.Contains(t, sql, "auth_identities_provider_type_check")
|
||||
require.Contains(t, sql, "auth_identity_channels_provider_type_check")
|
||||
require.Contains(t, sql, "pending_auth_sessions_provider_type_check")
|
||||
require.Contains(t, sql, "'github'")
|
||||
require.Contains(t, sql, "'google'")
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.2-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.3-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.20
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@ -202,6 +202,14 @@ gateway:
|
||||
#
|
||||
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
|
||||
force_codex_cli: false
|
||||
# Enable Codex image-generation bridge injection for /openai/v1/responses.
|
||||
# 是否为 Codex /responses 请求自动注入 image_generation 工具与桥接指令。
|
||||
#
|
||||
# Default false keeps text-only Codex requests text-only. Explicit client-provided
|
||||
# image_generation tools are still forwarded when the group allows image generation.
|
||||
# 默认 false:保持纯文本 Codex 请求不被改写;客户端显式提供 image_generation tool 时,
|
||||
# 仍会在分组允许图片生成的情况下正常转发。
|
||||
codex_image_generation_bridge_enabled: false
|
||||
# Optional: template file used to build the final top-level Codex `instructions`.
|
||||
# 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。
|
||||
#
|
||||
|
||||
@ -16,6 +16,8 @@ import type {
|
||||
TempUnschedulableStatus,
|
||||
AdminDataPayload,
|
||||
AdminDataImportResult,
|
||||
CodexSessionImportRequest,
|
||||
CodexSessionImportResult,
|
||||
CheckMixedChannelRequest,
|
||||
CheckMixedChannelResponse
|
||||
} from '@/types'
|
||||
@ -547,6 +549,11 @@ export async function importData(payload: {
|
||||
return data
|
||||
}
|
||||
|
||||
export async function importCodexSession(payload: CodexSessionImportRequest): Promise<CodexSessionImportResult> {
|
||||
const { data } = await apiClient.post<CodexSessionImportResult>('/admin/accounts/import/codex-session', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Antigravity default model mapping from backend
|
||||
* @returns Default model mapping (from -> to)
|
||||
@ -663,6 +670,7 @@ export const accountsAPI = {
|
||||
syncFromCrs,
|
||||
exportData,
|
||||
importData,
|
||||
importCodexSession,
|
||||
getAntigravityDefaultModelMapping,
|
||||
batchClearError,
|
||||
batchRefresh,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user