chore: merge upstream v0.1.124-125, keep Windsurf/Antigravity customizations
Upstream changes: - feat: 邮箱 + GitHub + Google OAuth 快捷登录 - feat: Codex image bridge 开关 - feat: 内容审核 (content moderation) — 新增 contentModerationService/Handler - feat: redeem code 返利、批量并发 API、markdown 页面渲染 - feat: 登录注册条款确认 - fix(security): pages API 加 JWT + 可见性校验 - fix: 修复 markdown 页面图片路径 - fix(gateway): 不再默认注入 redact thinking beta - fix: 稳定 anthropic passthrough 超时错误 - chore: VERSION 升到 0.1.125 + golang:1.26.3-alpine Conflict resolutions: - Dockerfile/backend/Dockerfile: 取 upstream golang:1.26.3-alpine - backend/go.mod: 取 upstream term v0.42.0,保留定制 protobuf v1.36.10 - frontend/src/api/admin/index.ts: 并集 (windsurf + riskControl) - backend/cmd/server/wire_gen.go: 接 upstream contentModeration*,保留 windsurfHandler/windsurfGatewayService/billingCacheService/requestEventBus;并通过 wire 重生成 - frontend/src/views/admin/AccountsView.vue: 采用 upstream 双层布局 + OpenAI Meta,保留 is_enterprise prop 和 Windsurf tier badge Note: - WIP commit (de048fad) preserved Windsurf tier access service / NLU extractor / ops log stream / Google OAuth login modal et al before merge. - 3 pre-existing go vet issues in test files (NewOpsHandler, RegisterGatewayRoutes, DefaultCLIProductVersion) are unrelated to this merge — leftover from local customization refactors; production code (go build ./...) passes.
This commit is contained in:
commit
7347dfffc1
4
.github/workflows/backend-ci.yml
vendored
4
.github/workflows/backend-ci.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
|||||||
cache-dependency-path: backend/go.sum
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.2'
|
go version | grep -q 'go1.26.3'
|
||||||
- name: Unit tests
|
- name: Unit tests
|
||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: make test-unit
|
run: make test-unit
|
||||||
@ -60,7 +60,7 @@ jobs:
|
|||||||
cache-dependency-path: backend/go.sum
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.2'
|
go version | grep -q 'go1.26.3'
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v9
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@ -115,7 +115,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.2'
|
go version | grep -q 'go1.26.3'
|
||||||
|
|
||||||
# Docker setup for GoReleaser
|
# Docker setup for GoReleaser
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
|
|||||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@ -23,7 +23,7 @@ jobs:
|
|||||||
cache-dependency-path: backend/go.sum
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.2'
|
go version | grep -q 'go1.26.3'
|
||||||
- name: Run govulncheck
|
- name: Run govulncheck
|
||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.26-alpine
|
ARG GOLANG_IMAGE=golang:1.26.3-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.21
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
|
|||||||
@ -72,8 +72,8 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||||
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
|
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/register?aff=SUB2API">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
|
|||||||
@ -71,8 +71,8 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
|||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||||
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
|
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/register?aff=SUB2API">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
|
|||||||
@ -71,8 +71,8 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
|||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
<td width="180"><a href="https://code.silkapi.com/register?aff=SUB2API"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
|
||||||
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
|
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/register?aff=SUB2API">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.26-alpine
|
FROM golang:1.26.3-alpine
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
0.1.123
|
0.1.125
|
||||||
|
|||||||
@ -8,11 +8,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/ent"
|
"github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
@ -23,9 +18,14 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
|
||||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -136,26 +136,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||||
|
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||||
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider)
|
||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider)
|
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
|
||||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||||
windsurfLSService := service.ProvideWindsurfLSService(configConfig)
|
windsurfLSService := service.ProvideWindsurfLSService(configConfig)
|
||||||
windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository)
|
windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository)
|
||||||
windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache)
|
windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache)
|
||||||
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
|
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||||
@ -237,19 +236,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
|
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
|
||||||
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
||||||
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
|
||||||
|
contentModerationRepository := repository.NewContentModerationRepository(db)
|
||||||
|
contentModerationHashCache := repository.NewContentModerationHashCache(redisClient)
|
||||||
|
contentModerationService := service.NewContentModerationService(settingRepository, contentModerationRepository, contentModerationHashCache, groupRepository, userRepository, apiKeyAuthCacheInvalidator, emailService)
|
||||||
|
contentModerationHandler := admin.NewContentModerationHandler(contentModerationService)
|
||||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||||
windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService)
|
windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService)
|
||||||
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
|
||||||
windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository)
|
windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository)
|
||||||
windsurfTierAccessService := service.ProvideWindsurfTierAccessService(configConfig, accountRepository)
|
windsurfTierAccessService := service.ProvideWindsurfTierAccessService(configConfig, accountRepository)
|
||||||
windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService, windsurfTierAccessService)
|
windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService, windsurfTierAccessService)
|
||||||
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler, affiliateHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, windsurfHandler, affiliateHandler)
|
||||||
|
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
totpHandler := handler.NewTotpHandler(totpService)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
||||||
@ -274,6 +277,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||||
|
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
||||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
@ -485,6 +489,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"WindsurfLSService", func() error {
|
||||||
|
if windsurfLS != nil {
|
||||||
|
windsurfLS.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@ -16,6 +16,8 @@ import (
|
|||||||
|
|
||||||
var authProviderTypes = map[string]struct{}{
|
var authProviderTypes = map[string]struct{}{
|
||||||
"email": {},
|
"email": {},
|
||||||
|
"github": {},
|
||||||
|
"google": {},
|
||||||
"linuxdo": {},
|
"linuxdo": {},
|
||||||
"oidc": {},
|
"oidc": {},
|
||||||
"wechat": {},
|
"wechat": {},
|
||||||
|
|||||||
@ -83,10 +83,10 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
|
|||||||
require.Equal(t, 1, signupSource.Validators)
|
require.Equal(t, 1, signupSource.Validators)
|
||||||
|
|
||||||
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
|
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
|
||||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
|
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
|
||||||
require.NoError(t, validator(value))
|
require.NoError(t, validator(value))
|
||||||
}
|
}
|
||||||
require.Error(t, validator("github"))
|
require.Error(t, validator("unknown"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
|
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
|
||||||
|
|||||||
@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
|
|||||||
field.String("signup_source").
|
field.String("signup_source").
|
||||||
Validate(func(value string) error {
|
Validate(func(value string) error {
|
||||||
switch value {
|
switch value {
|
||||||
case "email", "linuxdo", "wechat", "oidc":
|
case "email", "linuxdo", "wechat", "oidc", "github", "google":
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
|
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
|
||||||
}
|
}
|
||||||
}).
|
}).
|
||||||
Default("email"),
|
Default("email"),
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
module github.com/Wei-Shaw/sub2api
|
module github.com/Wei-Shaw/sub2api
|
||||||
|
|
||||||
go 1.26.2
|
go 1.26.3
|
||||||
|
|
||||||
require (
|
require (
|
||||||
connectrpc.com/connect v1.19.2
|
connectrpc.com/connect v1.19.2
|
||||||
@ -21,6 +21,7 @@ require (
|
|||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/imroc/req/v3 v3.57.0
|
github.com/imroc/req/v3 v3.57.0
|
||||||
|
github.com/klauspost/compress v1.18.2
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/pquerna/otp v1.5.0
|
github.com/pquerna/otp v1.5.0
|
||||||
@ -40,11 +41,11 @@ require (
|
|||||||
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
|
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
|
||||||
github.com/zeromicro/go-zero v1.9.4
|
github.com/zeromicro/go-zero v1.9.4
|
||||||
go.uber.org/zap v1.24.0
|
go.uber.org/zap v1.24.0
|
||||||
golang.org/x/crypto v0.49.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/image v0.39.0
|
golang.org/x/image v0.39.0
|
||||||
golang.org/x/net v0.52.0
|
golang.org/x/net v0.53.0
|
||||||
golang.org/x/sync v0.20.0
|
golang.org/x/sync v0.20.0
|
||||||
golang.org/x/term v0.41.0
|
golang.org/x/term v0.42.0
|
||||||
google.golang.org/protobuf v1.36.10
|
google.golang.org/protobuf v1.36.10
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
@ -112,7 +113,6 @@ require (
|
|||||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||||
github.com/icholy/digest v1.1.0 // indirect
|
github.com/icholy/digest v1.1.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/compress v1.18.2 // indirect
|
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
@ -176,7 +176,7 @@ require (
|
|||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/mod v0.34.0 // indirect
|
golang.org/x/mod v0.34.0 // indirect
|
||||||
golang.org/x/sys v0.42.0 // indirect
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
golang.org/x/text v0.36.0 // indirect
|
golang.org/x/text v0.36.0 // indirect
|
||||||
golang.org/x/tools v0.43.0 // indirect
|
golang.org/x/tools v0.43.0 // indirect
|
||||||
google.golang.org/grpc v1.75.1 // indirect
|
google.golang.org/grpc v1.75.1 // indirect
|
||||||
|
|||||||
@ -187,8 +187,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
|||||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@ -224,8 +222,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
|
||||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@ -259,8 +255,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@ -290,8 +284,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
|||||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
|
||||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
@ -324,8 +316,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
|
||||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
@ -417,16 +407,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
|
|||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
|
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
|
||||||
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
|
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
|
||||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@ -438,10 +428,10 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
|
|||||||
@ -79,6 +79,8 @@ type Config struct {
|
|||||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
|
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
|
||||||
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
||||||
|
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
|
||||||
|
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
@ -248,6 +250,19 @@ type OIDCConnectConfig struct {
|
|||||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmailOAuthProviderConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
ClientID string `mapstructure:"client_id"`
|
||||||
|
ClientSecret string `mapstructure:"client_secret"`
|
||||||
|
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||||
|
TokenURL string `mapstructure:"token_url"`
|
||||||
|
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||||
|
EmailsURL string `mapstructure:"emails_url"`
|
||||||
|
Scopes string `mapstructure:"scopes"`
|
||||||
|
RedirectURL string `mapstructure:"redirect_url"`
|
||||||
|
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultWeChatConnectMode = "open"
|
defaultWeChatConnectMode = "open"
|
||||||
defaultWeChatConnectScopes = "snsapi_login"
|
defaultWeChatConnectScopes = "snsapi_login"
|
||||||
@ -619,6 +634,9 @@ type GatewayConfig struct {
|
|||||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||||
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
|
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
|
||||||
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
|
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
|
||||||
|
// CodexImageGenerationBridgeEnabled: 是否为 Codex `/v1/responses` 自动注入 image_generation 工具和桥接指令。
|
||||||
|
// 默认关闭,避免纯文本 Codex 请求被意外改写;显式携带 image_generation 工具的请求仍按分组能力转发。
|
||||||
|
CodexImageGenerationBridgeEnabled bool `mapstructure:"codex_image_generation_bridge_enabled"`
|
||||||
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
|
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
|
||||||
// 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
|
// 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
|
||||||
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
|
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
|
||||||
@ -1773,6 +1791,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.max_account_switches", 10)
|
viper.SetDefault("gateway.max_account_switches", 10)
|
||||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||||
viper.SetDefault("gateway.force_codex_cli", false)
|
viper.SetDefault("gateway.force_codex_cli", false)
|
||||||
|
viper.SetDefault("gateway.codex_image_generation_bridge_enabled", false)
|
||||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||||
|
|||||||
1045
backend/internal/handler/admin/account_codex_import.go
Normal file
1045
backend/internal/handler/admin/account_codex_import.go
Normal file
File diff suppressed because it is too large
Load Diff
344
backend/internal/handler/admin/account_codex_import_test.go
Normal file
344
backend/internal/handler/admin/account_codex_import_test.go
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseCodexSessionImportEntriesSupportsRawTokenJSONAndArray(t *testing.T) {
|
||||||
|
token1 := "raw-access-token-1"
|
||||||
|
token2 := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{
|
||||||
|
"email": "json@example.com",
|
||||||
|
})
|
||||||
|
token3 := "raw-access-token-3"
|
||||||
|
|
||||||
|
req := CodexSessionImportRequest{
|
||||||
|
Content: fmt.Sprintf("%s\n{\"accessToken\":%q}\n[%q]", token1, token2, token3),
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := parseCodexSessionImportEntries(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseCodexSessionImportEntries error = %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 3 {
|
||||||
|
t.Fatalf("len(entries) = %d, want 3", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := normalizeCodexImportEntry(entries[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize raw token error = %v", err)
|
||||||
|
}
|
||||||
|
if first.Credentials["access_token"] != token1 {
|
||||||
|
t.Fatalf("raw token access_token = %v, want %s", first.Credentials["access_token"], token1)
|
||||||
|
}
|
||||||
|
|
||||||
|
second, err := normalizeCodexImportEntry(entries[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize json token error = %v", err)
|
||||||
|
}
|
||||||
|
if second.Email != "json@example.com" {
|
||||||
|
t.Fatalf("email = %q, want json@example.com", second.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
third, err := normalizeCodexImportEntry(entries[2])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize array token error = %v", err)
|
||||||
|
}
|
||||||
|
if third.Credentials["access_token"] != token3 {
|
||||||
|
t.Fatalf("array token access_token = %v, want %s", third.Credentials["access_token"], token3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexSessionImportEntriesFallsBackToLineModeForMixedJSONAndToken(t *testing.T) {
|
||||||
|
req := CodexSessionImportRequest{
|
||||||
|
Content: "{\"accessToken\":\"json-line-token\"}\nraw-line-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := parseCodexSessionImportEntries(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseCodexSessionImportEntries error = %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Fatalf("len(entries) = %d, want 2", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := normalizeCodexImportEntry(entries[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize json line error = %v", err)
|
||||||
|
}
|
||||||
|
if first.Credentials["access_token"] != "json-line-token" {
|
||||||
|
t.Fatalf("json line access_token = %v, want json-line-token", first.Credentials["access_token"])
|
||||||
|
}
|
||||||
|
|
||||||
|
second, err := normalizeCodexImportEntry(entries[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize raw line error = %v", err)
|
||||||
|
}
|
||||||
|
if second.Credentials["access_token"] != "raw-line-token" {
|
||||||
|
t.Fatalf("raw line access_token = %v, want raw-line-token", second.Credentials["access_token"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCodexSessionJSONExtractsCredentialsAndIgnoresSessionToken(t *testing.T) {
|
||||||
|
accessToken := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{
|
||||||
|
"email": "claim@example.com",
|
||||||
|
"https://api.openai.com/auth": map[string]any{
|
||||||
|
"chatgpt_account_id": "acct-from-claim",
|
||||||
|
"chatgpt_user_id": "user-from-claim",
|
||||||
|
"chatgpt_plan_type": "plus",
|
||||||
|
"poid": "org-from-claim",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
raw := map[string]any{
|
||||||
|
"user": map[string]any{
|
||||||
|
"id": "user-from-json",
|
||||||
|
"name": "Sup OO",
|
||||||
|
"email": "json@example.com",
|
||||||
|
"image": "https://example.com/avatar.png",
|
||||||
|
},
|
||||||
|
"account": map[string]any{
|
||||||
|
"id": "acct-from-json",
|
||||||
|
"planType": "free",
|
||||||
|
},
|
||||||
|
"accessToken": accessToken,
|
||||||
|
"sessionToken": "secret-session-token",
|
||||||
|
"expires": "2026-08-05T13:40:42.836Z",
|
||||||
|
}
|
||||||
|
|
||||||
|
item, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: raw})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeCodexImportEntry error = %v", err)
|
||||||
|
}
|
||||||
|
if item.Credentials["access_token"] != accessToken {
|
||||||
|
t.Fatalf("access_token not stored")
|
||||||
|
}
|
||||||
|
if item.Credentials["email"] != "json@example.com" {
|
||||||
|
t.Fatalf("email = %v, want json@example.com", item.Credentials["email"])
|
||||||
|
}
|
||||||
|
if item.Credentials["chatgpt_account_id"] != "acct-from-json" {
|
||||||
|
t.Fatalf("chatgpt_account_id = %v, want acct-from-json", item.Credentials["chatgpt_account_id"])
|
||||||
|
}
|
||||||
|
if item.Credentials["chatgpt_user_id"] != "user-from-json" {
|
||||||
|
t.Fatalf("chatgpt_user_id = %v, want user-from-json", item.Credentials["chatgpt_user_id"])
|
||||||
|
}
|
||||||
|
if item.Credentials["plan_type"] != "free" {
|
||||||
|
t.Fatalf("plan_type = %v, want free", item.Credentials["plan_type"])
|
||||||
|
}
|
||||||
|
if _, ok := item.Credentials["session_token"]; ok {
|
||||||
|
t.Fatalf("session_token should not be written to credentials")
|
||||||
|
}
|
||||||
|
if item.Extra["session_token_present"] != true {
|
||||||
|
t.Fatalf("session_token_present = %v, want true", item.Extra["session_token_present"])
|
||||||
|
}
|
||||||
|
if item.Extra["session_expires_at"] != "2026-08-05T13:40:42Z" {
|
||||||
|
t.Fatalf("session_expires_at = %v", item.Extra["session_expires_at"])
|
||||||
|
}
|
||||||
|
if item.TokenExpiresAt == nil {
|
||||||
|
t.Fatalf("TokenExpiresAt should be parsed from accessToken")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCodexImportCredentialsClearsStaleRefreshFieldsWhenIncomingHasNoRefreshToken(t *testing.T) {
|
||||||
|
existing := map[string]any{
|
||||||
|
"access_token": "old-access-token",
|
||||||
|
"refresh_token": "old-refresh-token",
|
||||||
|
"client_id": "old-client-id",
|
||||||
|
"id_token": "old-id-token",
|
||||||
|
"model_mapping": map[string]any{"from": "existing"},
|
||||||
|
"chatgpt_account_id": "acct-old",
|
||||||
|
"unrelated_existing": "keep",
|
||||||
|
}
|
||||||
|
incoming := map[string]any{
|
||||||
|
"access_token": "new-access-token",
|
||||||
|
"expires_at": "2026-08-05T13:40:42Z",
|
||||||
|
"chatgpt_account_id": "acct-new",
|
||||||
|
}
|
||||||
|
item := &codexImportAccount{
|
||||||
|
AccessToken: "new-access-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := mergeCodexImportCredentials(existing, incoming, item)
|
||||||
|
|
||||||
|
if merged["access_token"] != "new-access-token" {
|
||||||
|
t.Fatalf("access_token = %v, want new-access-token", merged["access_token"])
|
||||||
|
}
|
||||||
|
if merged["chatgpt_account_id"] != "acct-new" {
|
||||||
|
t.Fatalf("chatgpt_account_id = %v, want acct-new", merged["chatgpt_account_id"])
|
||||||
|
}
|
||||||
|
if _, ok := merged["refresh_token"]; ok {
|
||||||
|
t.Fatalf("refresh_token should be cleared")
|
||||||
|
}
|
||||||
|
if _, ok := merged["client_id"]; ok {
|
||||||
|
t.Fatalf("client_id should be cleared")
|
||||||
|
}
|
||||||
|
if _, ok := merged["id_token"]; ok {
|
||||||
|
t.Fatalf("id_token should be cleared")
|
||||||
|
}
|
||||||
|
if merged["unrelated_existing"] != "keep" {
|
||||||
|
t.Fatalf("unrelated_existing = %v, want keep", merged["unrelated_existing"])
|
||||||
|
}
|
||||||
|
if _, ok := merged["model_mapping"]; !ok {
|
||||||
|
t.Fatalf("model_mapping should be preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCodexImportCredentialsKeepsRefreshFieldsWhenIncomingHasRefreshToken(t *testing.T) {
|
||||||
|
existing := map[string]any{
|
||||||
|
"refresh_token": "old-refresh-token",
|
||||||
|
"client_id": "old-client-id",
|
||||||
|
"id_token": "old-id-token",
|
||||||
|
}
|
||||||
|
incoming := map[string]any{
|
||||||
|
"access_token": "new-access-token",
|
||||||
|
"refresh_token": "new-refresh-token",
|
||||||
|
"client_id": "new-client-id",
|
||||||
|
"id_token": "new-id-token",
|
||||||
|
}
|
||||||
|
item := &codexImportAccount{
|
||||||
|
AccessToken: "new-access-token",
|
||||||
|
RefreshToken: "new-refresh-token",
|
||||||
|
IDToken: "new-id-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := mergeCodexImportCredentials(existing, incoming, item)
|
||||||
|
|
||||||
|
if merged["refresh_token"] != "new-refresh-token" {
|
||||||
|
t.Fatalf("refresh_token = %v, want new-refresh-token", merged["refresh_token"])
|
||||||
|
}
|
||||||
|
if merged["client_id"] != "new-client-id" {
|
||||||
|
t.Fatalf("client_id = %v, want new-client-id", merged["client_id"])
|
||||||
|
}
|
||||||
|
if merged["id_token"] != "new-id-token" {
|
||||||
|
t.Fatalf("id_token = %v, want new-id-token", merged["id_token"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCodexImportRejectsExpiredAccessToken(t *testing.T) {
|
||||||
|
expiredToken := buildCodexImportTestJWT(t, time.Now().Add(-time.Hour), map[string]any{})
|
||||||
|
|
||||||
|
_, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: expiredToken})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("normalizeCodexImportEntry error = nil, want expired token error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "已过期") {
|
||||||
|
t.Fatalf("error = %v, want expired token message", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCodexImportExpiryForNoRefreshTokenUsesTokenExpiry(t *testing.T) {
|
||||||
|
tokenExpiresAt := time.Now().Add(time.Hour).UTC()
|
||||||
|
item := &codexImportAccount{
|
||||||
|
AccessToken: "access-token",
|
||||||
|
Credentials: map[string]any{"access_token": "access-token"},
|
||||||
|
TokenExpiresAt: &tokenExpiresAt,
|
||||||
|
WarningTexts: []string{},
|
||||||
|
}
|
||||||
|
disabled := false
|
||||||
|
req := CodexSessionImportRequest{AutoPauseOnExpired: &disabled}
|
||||||
|
|
||||||
|
accountExpiresAt, credentialExpiresAt, autoPause, warnings, err := resolveCodexImportExpiry(req, item)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveCodexImportExpiry error = %v", err)
|
||||||
|
}
|
||||||
|
if accountExpiresAt == nil || *accountExpiresAt != tokenExpiresAt.Unix() {
|
||||||
|
t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, tokenExpiresAt.Unix())
|
||||||
|
}
|
||||||
|
if credentialExpiresAt == nil || credentialExpiresAt.Unix() != tokenExpiresAt.Unix() {
|
||||||
|
t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, tokenExpiresAt)
|
||||||
|
}
|
||||||
|
if autoPause == nil || !*autoPause {
|
||||||
|
t.Fatalf("autoPause = %v, want true", autoPause)
|
||||||
|
}
|
||||||
|
if len(warnings) == 0 {
|
||||||
|
t.Fatalf("warnings should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCodexImportExpiryForNoRefreshTokenRequiresExpiry(t *testing.T) {
|
||||||
|
item := &codexImportAccount{
|
||||||
|
AccessToken: "opaque-access-token",
|
||||||
|
Credentials: map[string]any{"access_token": "opaque-access-token"},
|
||||||
|
WarningTexts: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, _, err := resolveCodexImportExpiry(CodexSessionImportRequest{}, item)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("resolveCodexImportExpiry error = nil, want missing expiry error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "无法解析 accessToken 过期时间") {
|
||||||
|
t.Fatalf("error = %v, want missing expiry message", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCodexImportExpiryForNoRefreshTokenUsesEarlierRequestExpiry(t *testing.T) {
|
||||||
|
tokenExpiresAt := time.Now().Add(2 * time.Hour).UTC()
|
||||||
|
requestExpiresAt := time.Now().Add(time.Hour).UTC()
|
||||||
|
item := &codexImportAccount{
|
||||||
|
AccessToken: "access-token",
|
||||||
|
Credentials: map[string]any{"access_token": "access-token"},
|
||||||
|
TokenExpiresAt: &tokenExpiresAt,
|
||||||
|
WarningTexts: []string{},
|
||||||
|
}
|
||||||
|
reqUnix := requestExpiresAt.Unix()
|
||||||
|
req := CodexSessionImportRequest{ExpiresAt: &reqUnix}
|
||||||
|
|
||||||
|
accountExpiresAt, credentialExpiresAt, _, _, err := resolveCodexImportExpiry(req, item)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveCodexImportExpiry error = %v", err)
|
||||||
|
}
|
||||||
|
if accountExpiresAt == nil || *accountExpiresAt != requestExpiresAt.Unix() {
|
||||||
|
t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, requestExpiresAt.Unix())
|
||||||
|
}
|
||||||
|
if credentialExpiresAt == nil || credentialExpiresAt.Unix() != requestExpiresAt.Unix() {
|
||||||
|
t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, requestExpiresAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexIdentityKeysPreferStrongIdentifiers(t *testing.T) {
|
||||||
|
keys := buildCodexIdentityKeys("acct-1", "user-1", "same@example.com", "token")
|
||||||
|
for _, key := range keys {
|
||||||
|
if strings.HasPrefix(key, "email:") {
|
||||||
|
t.Fatalf("strong identity should not include email fallback: %v", keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keys = buildCodexIdentityKeys("", "", "same@example.com", "token")
|
||||||
|
hasEmail := false
|
||||||
|
for _, key := range keys {
|
||||||
|
if key == "email:same@example.com" {
|
||||||
|
hasEmail = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasEmail {
|
||||||
|
t.Fatalf("weak identity should include email fallback: %v", keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCodexImportTestJWT(t *testing.T, exp time.Time, extraClaims map[string]any) string {
|
||||||
|
t.Helper()
|
||||||
|
header := map[string]any{
|
||||||
|
"alg": "none",
|
||||||
|
"typ": "JWT",
|
||||||
|
}
|
||||||
|
claims := map[string]any{
|
||||||
|
"sub": "user-from-sub",
|
||||||
|
"exp": exp.Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
for k, v := range extraClaims {
|
||||||
|
claims[k] = v
|
||||||
|
}
|
||||||
|
headerBytes, err := json.Marshal(header)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal header: %v", err)
|
||||||
|
}
|
||||||
|
claimBytes, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal claims: %v", err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(headerBytes) + "." + base64.RawURLEncoding.EncodeToString(claimBytes) + "."
|
||||||
|
}
|
||||||
@ -175,6 +175,10 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||||
|
return len(userIDs), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
|
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
|
||||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||||
}
|
}
|
||||||
|
|||||||
238
backend/internal/handler/admin/content_moderation_handler.go
Normal file
238
backend/internal/handler/admin/content_moderation_handler.go
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ContentModerationHandler struct {
|
||||||
|
service *service.ContentModerationService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler {
|
||||||
|
return &ContentModerationHandler{service: svc}
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentModerationConfigRequest struct {
|
||||||
|
Enabled *bool `json:"enabled"`
|
||||||
|
Mode *string `json:"mode"`
|
||||||
|
BaseURL *string `json:"base_url"`
|
||||||
|
Model *string `json:"model"`
|
||||||
|
APIKey *string `json:"api_key"`
|
||||||
|
APIKeys *[]string `json:"api_keys"`
|
||||||
|
APIKeysMode string `json:"api_keys_mode"`
|
||||||
|
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
|
||||||
|
ClearAPIKey bool `json:"clear_api_key"`
|
||||||
|
TimeoutMS *int `json:"timeout_ms"`
|
||||||
|
SampleRate *int `json:"sample_rate"`
|
||||||
|
AllGroups *bool `json:"all_groups"`
|
||||||
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
|
RecordNonHits *bool `json:"record_non_hits"`
|
||||||
|
WorkerCount *int `json:"worker_count"`
|
||||||
|
QueueSize *int `json:"queue_size"`
|
||||||
|
BlockStatus *int `json:"block_status"`
|
||||||
|
BlockMessage *string `json:"block_message"`
|
||||||
|
EmailOnHit *bool `json:"email_on_hit"`
|
||||||
|
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||||
|
BanThreshold *int `json:"ban_threshold"`
|
||||||
|
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||||
|
RetryCount *int `json:"retry_count"`
|
||||||
|
HitRetentionDays *int `json:"hit_retention_days"`
|
||||||
|
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||||
|
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentModerationAPIKeyTestRequest struct {
|
||||||
|
APIKeys []string `json:"api_keys"`
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
TimeoutMS int `json:"timeout_ms"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Images []string `json:"images"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentModerationHashRequest struct {
|
||||||
|
InputHash string `json:"input_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) GetConfig(c *gin.Context) {
|
||||||
|
cfg, err := h.service.GetConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
|
||||||
|
var req contentModerationConfigRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Mode: req.Mode,
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
Model: req.Model,
|
||||||
|
APIKey: req.APIKey,
|
||||||
|
APIKeys: req.APIKeys,
|
||||||
|
APIKeysMode: req.APIKeysMode,
|
||||||
|
DeleteAPIKeyHashes: req.DeleteAPIKeyHashes,
|
||||||
|
ClearAPIKey: req.ClearAPIKey,
|
||||||
|
TimeoutMS: req.TimeoutMS,
|
||||||
|
SampleRate: req.SampleRate,
|
||||||
|
AllGroups: req.AllGroups,
|
||||||
|
GroupIDs: req.GroupIDs,
|
||||||
|
RecordNonHits: req.RecordNonHits,
|
||||||
|
WorkerCount: req.WorkerCount,
|
||||||
|
QueueSize: req.QueueSize,
|
||||||
|
BlockStatus: req.BlockStatus,
|
||||||
|
BlockMessage: req.BlockMessage,
|
||||||
|
EmailOnHit: req.EmailOnHit,
|
||||||
|
AutoBanEnabled: req.AutoBanEnabled,
|
||||||
|
BanThreshold: req.BanThreshold,
|
||||||
|
ViolationWindowHours: req.ViolationWindowHours,
|
||||||
|
RetryCount: req.RetryCount,
|
||||||
|
HitRetentionDays: req.HitRetentionDays,
|
||||||
|
NonHitRetentionDays: req.NonHitRetentionDays,
|
||||||
|
PreHashCheckEnabled: req.PreHashCheckEnabled,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) {
|
||||||
|
var req contentModerationAPIKeyTestRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{
|
||||||
|
APIKeys: req.APIKeys,
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
Model: req.Model,
|
||||||
|
TimeoutMS: req.TimeoutMS,
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
Images: req.Images,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) GetStatus(c *gin.Context) {
|
||||||
|
status, err := h.service.GetStatus(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) ListLogs(c *gin.Context) {
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
filter := service.ContentModerationLogFilter{
|
||||||
|
Pagination: pagination.PaginationParams{
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
SortOrder: pagination.SortOrderDesc,
|
||||||
|
},
|
||||||
|
Result: c.Query("result"),
|
||||||
|
Endpoint: c.Query("endpoint"),
|
||||||
|
Search: c.Query("search"),
|
||||||
|
}
|
||||||
|
if raw := strings.TrimSpace(c.Query("group_id")); raw != "" {
|
||||||
|
groupID, err := strconv.ParseInt(raw, 10, 64)
|
||||||
|
if err != nil || groupID <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &groupID
|
||||||
|
}
|
||||||
|
if raw := strings.TrimSpace(c.Query("from")); raw != "" {
|
||||||
|
t, _, err := parseContentModerationDate(raw)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid from")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.From = &t
|
||||||
|
}
|
||||||
|
if raw := strings.TrimSpace(c.Query("to")); raw != "" {
|
||||||
|
t, dateOnly, err := parseContentModerationDate(raw)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid to")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if dateOnly {
|
||||||
|
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||||
|
}
|
||||||
|
filter.To = &t
|
||||||
|
}
|
||||||
|
items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) UnbanUser(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64)
|
||||||
|
if err != nil || userID <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid user_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.service.UnbanUser(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) {
|
||||||
|
var req contentModerationHashRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) {
|
||||||
|
result, err := h.service.ClearFlaggedInputHashes(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseContentModerationDate(raw string) (time.Time, bool, error) {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return time.Time{}, false, nil
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||||
|
return t, false, nil
|
||||||
|
}
|
||||||
|
t, err := time.Parse("2006-01-02", raw)
|
||||||
|
return t, err == nil, err
|
||||||
|
}
|
||||||
@ -117,6 +117,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
|
LoginAgreementEnabled: settings.LoginAgreementEnabled,
|
||||||
|
LoginAgreementMode: settings.LoginAgreementMode,
|
||||||
|
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
|
||||||
|
LoginAgreementDocuments: loginAgreementDocumentsToDTO(settings.LoginAgreementDocuments),
|
||||||
SMTPHost: settings.SMTPHost,
|
SMTPHost: settings.SMTPHost,
|
||||||
SMTPPort: settings.SMTPPort,
|
SMTPPort: settings.SMTPPort,
|
||||||
SMTPUsername: settings.SMTPUsername,
|
SMTPUsername: settings.SMTPUsername,
|
||||||
@ -169,6 +173,16 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
|
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
|
||||||
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
|
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
|
||||||
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
|
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
|
||||||
|
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||||
|
GitHubOAuthClientID: settings.GitHubOAuthClientID,
|
||||||
|
GitHubOAuthClientSecretConfigured: settings.GitHubOAuthClientSecretConfigured,
|
||||||
|
GitHubOAuthRedirectURL: settings.GitHubOAuthRedirectURL,
|
||||||
|
GitHubOAuthFrontendRedirectURL: settings.GitHubOAuthFrontendRedirectURL,
|
||||||
|
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||||
|
GoogleOAuthClientID: settings.GoogleOAuthClientID,
|
||||||
|
GoogleOAuthClientSecretConfigured: settings.GoogleOAuthClientSecretConfigured,
|
||||||
|
GoogleOAuthRedirectURL: settings.GoogleOAuthRedirectURL,
|
||||||
|
GoogleOAuthFrontendRedirectURL: settings.GoogleOAuthFrontendRedirectURL,
|
||||||
SiteName: settings.SiteName,
|
SiteName: settings.SiteName,
|
||||||
SiteLogo: settings.SiteLogo,
|
SiteLogo: settings.SiteLogo,
|
||||||
SiteSubtitle: settings.SiteSubtitle,
|
SiteSubtitle: settings.SiteSubtitle,
|
||||||
@ -185,6 +199,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||||
DefaultConcurrency: settings.DefaultConcurrency,
|
DefaultConcurrency: settings.DefaultConcurrency,
|
||||||
DefaultBalance: settings.DefaultBalance,
|
DefaultBalance: settings.DefaultBalance,
|
||||||
|
RiskControlEnabled: settings.RiskControlEnabled,
|
||||||
AffiliateRebateRate: settings.AffiliateRebateRate,
|
AffiliateRebateRate: settings.AffiliateRebateRate,
|
||||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||||
@ -294,17 +309,50 @@ func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.O
|
|||||||
return &service.OpenAIFastPolicySettings{Rules: rules}
|
return &service.OpenAIFastPolicySettings{Rules: rules}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
||||||
|
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
result = append(result, dto.LoginAgreementDocument{
|
||||||
|
ID: item.ID,
|
||||||
|
Title: item.Title,
|
||||||
|
ContentMD: item.ContentMD,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func loginAgreementDocumentsToService(items []dto.LoginAgreementDocument) []service.LoginAgreementDocument {
|
||||||
|
result := make([]service.LoginAgreementDocument, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
title := strings.TrimSpace(item.Title)
|
||||||
|
content := strings.TrimSpace(item.ContentMD)
|
||||||
|
if title == "" && content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, service.LoginAgreementDocument{
|
||||||
|
ID: strings.TrimSpace(item.ID),
|
||||||
|
Title: title,
|
||||||
|
ContentMD: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSettingsRequest 更新设置请求
|
// UpdateSettingsRequest 更新设置请求
|
||||||
type UpdateSettingsRequest struct {
|
type UpdateSettingsRequest struct {
|
||||||
// 注册设置
|
// 注册设置
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
FrontendURL string `json:"frontend_url"`
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||||
|
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||||
|
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||||
|
LoginAgreementDocuments []dto.LoginAgreementDocument `json:"login_agreement_documents"`
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
@ -368,6 +416,17 @@ type UpdateSettingsRequest struct {
|
|||||||
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
||||||
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
||||||
|
|
||||||
|
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||||
|
GitHubOAuthClientID string `json:"github_oauth_client_id"`
|
||||||
|
GitHubOAuthClientSecret string `json:"github_oauth_client_secret"`
|
||||||
|
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
|
||||||
|
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
|
||||||
|
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||||
|
GoogleOAuthClientID string `json:"google_oauth_client_id"`
|
||||||
|
GoogleOAuthClientSecret string `json:"google_oauth_client_secret"`
|
||||||
|
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
|
||||||
|
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SiteName string `json:"site_name"`
|
SiteName string `json:"site_name"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteLogo string `json:"site_logo"`
|
||||||
@ -413,6 +472,16 @@ type UpdateSettingsRequest struct {
|
|||||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||||
|
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||||
|
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||||
|
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||||
|
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||||
|
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||||
|
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||||
|
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||||
|
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||||
|
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||||
|
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||||
|
|
||||||
// Model fallback configuration
|
// Model fallback configuration
|
||||||
@ -497,6 +566,9 @@ type UpdateSettingsRequest struct {
|
|||||||
// Affiliate (邀请返利) feature switch
|
// Affiliate (邀请返利) feature switch
|
||||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||||
|
|
||||||
|
// 风控中心功能开关
|
||||||
|
RiskControlEnabled *bool `json:"risk_control_enabled"`
|
||||||
|
|
||||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||||
}
|
}
|
||||||
@ -633,6 +705,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
loginAgreementMode := strings.ToLower(strings.TrimSpace(req.LoginAgreementMode))
|
||||||
|
if loginAgreementMode == "" {
|
||||||
|
loginAgreementMode = strings.ToLower(strings.TrimSpace(previousSettings.LoginAgreementMode))
|
||||||
|
}
|
||||||
|
switch loginAgreementMode {
|
||||||
|
case "", "modal":
|
||||||
|
loginAgreementMode = "modal"
|
||||||
|
case "checkbox":
|
||||||
|
default:
|
||||||
|
response.BadRequest(c, "Login agreement mode must be modal or checkbox")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
loginAgreementUpdatedAt := strings.TrimSpace(req.LoginAgreementUpdatedAt)
|
||||||
|
if loginAgreementUpdatedAt == "" {
|
||||||
|
loginAgreementUpdatedAt = strings.TrimSpace(previousSettings.LoginAgreementUpdatedAt)
|
||||||
|
}
|
||||||
|
loginAgreementDocuments := loginAgreementDocumentsToService(req.LoginAgreementDocuments)
|
||||||
|
if len(loginAgreementDocuments) == 0 {
|
||||||
|
loginAgreementDocuments = previousSettings.LoginAgreementDocuments
|
||||||
|
}
|
||||||
|
for _, doc := range loginAgreementDocuments {
|
||||||
|
if strings.TrimSpace(doc.Title) == "" {
|
||||||
|
response.BadRequest(c, "Login agreement document title is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(doc.Title) > 80 {
|
||||||
|
response.BadRequest(c, "Login agreement document title is too long (max 80 characters)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(doc.ContentMD) > 200*1024 {
|
||||||
|
response.BadRequest(c, "Login agreement document content is too large (max 200KB)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.LoginAgreementEnabled && len(loginAgreementDocuments) == 0 {
|
||||||
|
response.BadRequest(c, "Login agreement documents are required when enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// LinuxDo Connect 参数验证
|
// LinuxDo Connect 参数验证
|
||||||
if req.LinuxDoConnectEnabled {
|
if req.LinuxDoConnectEnabled {
|
||||||
@ -994,17 +1104,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
|
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(item.URL) == "" {
|
urlTrimmed := strings.TrimSpace(item.URL)
|
||||||
response.BadRequest(c, "Custom menu item URL is required")
|
if strings.HasPrefix(urlTrimmed, "md:") {
|
||||||
return
|
// Markdown page mode: URL = "md:<slug>"
|
||||||
}
|
slug := strings.TrimPrefix(urlTrimmed, "md:")
|
||||||
if len(item.URL) > maxMenuItemURLLen {
|
if slug == "" {
|
||||||
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
response.BadRequest(c, "Custom menu item markdown slug cannot be empty (use md:slug format)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
|
} else {
|
||||||
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
|
if urlTrimmed == "" {
|
||||||
return
|
response.BadRequest(c, "Custom menu item URL is required (use md:slug for markdown pages)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(item.URL) > maxMenuItemURLLen {
|
||||||
|
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(urlTrimmed); err != nil {
|
||||||
|
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL or md:<slug>")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if item.Visibility != "user" && item.Visibility != "admin" {
|
if item.Visibility != "user" && item.Visibility != "admin" {
|
||||||
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
|
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
|
||||||
@ -1148,6 +1268,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
FrontendURL: req.FrontendURL,
|
FrontendURL: req.FrontendURL,
|
||||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
TotpEnabled: req.TotpEnabled,
|
||||||
|
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||||
|
LoginAgreementMode: loginAgreementMode,
|
||||||
|
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||||
|
LoginAgreementDocuments: loginAgreementDocuments,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
SMTPPort: req.SMTPPort,
|
SMTPPort: req.SMTPPort,
|
||||||
SMTPUsername: req.SMTPUsername,
|
SMTPUsername: req.SMTPUsername,
|
||||||
@ -1200,6 +1324,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||||
|
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||||
|
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||||
|
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||||
|
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||||
|
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||||
|
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||||
|
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||||
|
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||||
|
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||||
|
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||||
SiteName: req.SiteName,
|
SiteName: req.SiteName,
|
||||||
SiteLogo: req.SiteLogo,
|
SiteLogo: req.SiteLogo,
|
||||||
SiteSubtitle: req.SiteSubtitle,
|
SiteSubtitle: req.SiteSubtitle,
|
||||||
@ -1365,6 +1499,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return previousSettings.AffiliateEnabled
|
return previousSettings.AffiliateEnabled
|
||||||
}(),
|
}(),
|
||||||
|
RiskControlEnabled: func() bool {
|
||||||
|
if req.RiskControlEnabled != nil {
|
||||||
|
return *req.RiskControlEnabled
|
||||||
|
}
|
||||||
|
return previousSettings.RiskControlEnabled
|
||||||
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||||
@ -1396,6 +1536,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
||||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
||||||
},
|
},
|
||||||
|
GitHub: service.ProviderDefaultGrantSettings{
|
||||||
|
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
|
||||||
|
Concurrency: intValueOrDefault(req.AuthSourceDefaultGitHubConcurrency, previousAuthSourceDefaults.GitHub.Concurrency),
|
||||||
|
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
|
||||||
|
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
|
||||||
|
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
|
||||||
|
},
|
||||||
|
Google: service.ProviderDefaultGrantSettings{
|
||||||
|
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
|
||||||
|
Concurrency: intValueOrDefault(req.AuthSourceDefaultGoogleConcurrency, previousAuthSourceDefaults.Google.Concurrency),
|
||||||
|
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
|
||||||
|
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||||
|
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||||
|
},
|
||||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||||
}
|
}
|
||||||
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
||||||
@ -1486,6 +1640,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
TotpEnabled: updatedSettings.TotpEnabled,
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
|
LoginAgreementEnabled: updatedSettings.LoginAgreementEnabled,
|
||||||
|
LoginAgreementMode: updatedSettings.LoginAgreementMode,
|
||||||
|
LoginAgreementUpdatedAt: updatedSettings.LoginAgreementUpdatedAt,
|
||||||
|
LoginAgreementDocuments: loginAgreementDocumentsToDTO(updatedSettings.LoginAgreementDocuments),
|
||||||
SMTPHost: updatedSettings.SMTPHost,
|
SMTPHost: updatedSettings.SMTPHost,
|
||||||
SMTPPort: updatedSettings.SMTPPort,
|
SMTPPort: updatedSettings.SMTPPort,
|
||||||
SMTPUsername: updatedSettings.SMTPUsername,
|
SMTPUsername: updatedSettings.SMTPUsername,
|
||||||
@ -1538,6 +1696,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
|
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
|
||||||
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
|
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
|
||||||
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
|
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
|
||||||
|
GitHubOAuthEnabled: updatedSettings.GitHubOAuthEnabled,
|
||||||
|
GitHubOAuthClientID: updatedSettings.GitHubOAuthClientID,
|
||||||
|
GitHubOAuthClientSecretConfigured: updatedSettings.GitHubOAuthClientSecretConfigured,
|
||||||
|
GitHubOAuthRedirectURL: updatedSettings.GitHubOAuthRedirectURL,
|
||||||
|
GitHubOAuthFrontendRedirectURL: updatedSettings.GitHubOAuthFrontendRedirectURL,
|
||||||
|
GoogleOAuthEnabled: updatedSettings.GoogleOAuthEnabled,
|
||||||
|
GoogleOAuthClientID: updatedSettings.GoogleOAuthClientID,
|
||||||
|
GoogleOAuthClientSecretConfigured: updatedSettings.GoogleOAuthClientSecretConfigured,
|
||||||
|
GoogleOAuthRedirectURL: updatedSettings.GoogleOAuthRedirectURL,
|
||||||
|
GoogleOAuthFrontendRedirectURL: updatedSettings.GoogleOAuthFrontendRedirectURL,
|
||||||
SiteName: updatedSettings.SiteName,
|
SiteName: updatedSettings.SiteName,
|
||||||
SiteLogo: updatedSettings.SiteLogo,
|
SiteLogo: updatedSettings.SiteLogo,
|
||||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||||
@ -1616,6 +1784,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||||
|
|
||||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||||
|
|
||||||
|
RiskControlEnabled: updatedSettings.RiskControlEnabled,
|
||||||
}
|
}
|
||||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||||
@ -1685,6 +1855,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.TotpEnabled != after.TotpEnabled {
|
if before.TotpEnabled != after.TotpEnabled {
|
||||||
changed = append(changed, "totp_enabled")
|
changed = append(changed, "totp_enabled")
|
||||||
}
|
}
|
||||||
|
if before.LoginAgreementEnabled != after.LoginAgreementEnabled {
|
||||||
|
changed = append(changed, "login_agreement_enabled")
|
||||||
|
}
|
||||||
|
if before.LoginAgreementMode != after.LoginAgreementMode {
|
||||||
|
changed = append(changed, "login_agreement_mode")
|
||||||
|
}
|
||||||
|
if before.LoginAgreementUpdatedAt != after.LoginAgreementUpdatedAt {
|
||||||
|
changed = append(changed, "login_agreement_updated_at")
|
||||||
|
}
|
||||||
|
if !equalLoginAgreementDocuments(before.LoginAgreementDocuments, after.LoginAgreementDocuments) {
|
||||||
|
changed = append(changed, "login_agreement_documents")
|
||||||
|
}
|
||||||
if before.SMTPHost != after.SMTPHost {
|
if before.SMTPHost != after.SMTPHost {
|
||||||
changed = append(changed, "smtp_host")
|
changed = append(changed, "smtp_host")
|
||||||
}
|
}
|
||||||
@ -2004,6 +2186,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.AffiliateEnabled != after.AffiliateEnabled {
|
if before.AffiliateEnabled != after.AffiliateEnabled {
|
||||||
changed = append(changed, "affiliate_enabled")
|
changed = append(changed, "affiliate_enabled")
|
||||||
}
|
}
|
||||||
|
if before.RiskControlEnabled != after.RiskControlEnabled {
|
||||||
|
changed = append(changed, "risk_control_enabled")
|
||||||
|
}
|
||||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||||
return changed
|
return changed
|
||||||
}
|
}
|
||||||
@ -2027,6 +2212,8 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
|||||||
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
|
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
|
||||||
{name: "oidc", before: before.OIDC, after: after.OIDC},
|
{name: "oidc", before: before.OIDC, after: after.OIDC},
|
||||||
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
||||||
|
{name: "github", before: before.GitHub, after: after.GitHub},
|
||||||
|
{name: "google", before: before.Google, after: after.Google},
|
||||||
}
|
}
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.before.Balance != field.after.Balance {
|
if field.before.Balance != field.after.Balance {
|
||||||
@ -2141,6 +2328,16 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
|||||||
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
|
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
|
||||||
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
|
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
|
||||||
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
|
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
|
||||||
|
data["auth_source_default_github_balance"] = authSourceDefaults.GitHub.Balance
|
||||||
|
data["auth_source_default_github_concurrency"] = authSourceDefaults.GitHub.Concurrency
|
||||||
|
data["auth_source_default_github_subscriptions"] = authSourceDefaults.GitHub.Subscriptions
|
||||||
|
data["auth_source_default_github_grant_on_signup"] = authSourceDefaults.GitHub.GrantOnSignup
|
||||||
|
data["auth_source_default_github_grant_on_first_bind"] = authSourceDefaults.GitHub.GrantOnFirstBind
|
||||||
|
data["auth_source_default_google_balance"] = authSourceDefaults.Google.Balance
|
||||||
|
data["auth_source_default_google_concurrency"] = authSourceDefaults.Google.Concurrency
|
||||||
|
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
|
||||||
|
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
|
||||||
|
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
|
||||||
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
||||||
|
|
||||||
return data
|
return data
|
||||||
@ -2170,6 +2367,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func equalLoginAgreementDocuments(a, b []service.LoginAgreementDocument) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i].ID != b[i].ID || a[i].Title != b[i].Title || a[i].ContentMD != b[i].ContentMD {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func equalIntSlice(a, b []int) bool {
|
func equalIntSlice(a, b []int) bool {
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -477,3 +477,63 @@ func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, status)
|
response.Success(c, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchUpdateConcurrency 批量修改用户并发数
|
||||||
|
// POST /api/v1/admin/users/batch-concurrency
|
||||||
|
type BatchUpdateConcurrencyRequest struct {
|
||||||
|
UserIDs []int64 `json:"user_ids"`
|
||||||
|
All bool `json:"all"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
|
Mode string `json:"mode" binding:"required,oneof=set add"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
|
||||||
|
var req BatchUpdateConcurrencyRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.All && len(req.UserIDs) == 0 {
|
||||||
|
response.BadRequest(c, "user_ids is required unless all=true")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.UserIDs) > 500 {
|
||||||
|
response.BadRequest(c, "user_ids cannot exceed 500")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var userIDs []int64
|
||||||
|
if req.All {
|
||||||
|
// Fetch all user IDs via pagination
|
||||||
|
page := 1
|
||||||
|
const pageSize = 500
|
||||||
|
for {
|
||||||
|
users, _, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, service.UserListFilters{}, "id", "asc")
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, u := range users {
|
||||||
|
userIDs = append(userIDs, u.ID)
|
||||||
|
}
|
||||||
|
if len(users) < pageSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
page++
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
userIDs = req.UserIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
response.Success(c, gin.H{"affected": 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := h.adminService.BatchUpdateConcurrency(c.Request.Context(), userIDs, req.Concurrency, req.Mode)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"affected": affected})
|
||||||
|
}
|
||||||
|
|||||||
621
backend/internal/handler/auth_email_oauth.go
Normal file
621
backend/internal/handler/auth_email_oauth.go
Normal file
@ -0,0 +1,621 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
emailOAuthCookiePath = "/api/v1/auth/oauth"
|
||||||
|
emailOAuthStateCookieName = "email_oauth_state"
|
||||||
|
emailOAuthRedirectCookie = "email_oauth_redirect"
|
||||||
|
emailOAuthProviderCookie = "email_oauth_provider"
|
||||||
|
emailOAuthAffiliateCookie = "email_oauth_affiliate"
|
||||||
|
emailOAuthCookieMaxAgeSec = 10 * 60
|
||||||
|
emailOAuthDefaultRedirect = "/dashboard"
|
||||||
|
)
|
||||||
|
|
||||||
|
type emailOAuthTokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
Scope string `json:"scope,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type emailOAuthProfile struct {
|
||||||
|
Subject string
|
||||||
|
Email string
|
||||||
|
EmailVerified bool
|
||||||
|
Username string
|
||||||
|
DisplayName string
|
||||||
|
AvatarURL string
|
||||||
|
Metadata map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) GitHubOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "github") }
|
||||||
|
func (h *AuthHandler) GoogleOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "google") }
|
||||||
|
|
||||||
|
func (h *AuthHandler) GitHubOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "github") }
|
||||||
|
func (h *AuthHandler) GoogleOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "google") }
|
||||||
|
func (h *AuthHandler) CompleteGitHubOAuthRegistration(c *gin.Context) {
|
||||||
|
h.completeEmailOAuthRegistration(c, "github")
|
||||||
|
}
|
||||||
|
func (h *AuthHandler) CompleteGoogleOAuthRegistration(c *gin.Context) {
|
||||||
|
h.completeEmailOAuthRegistration(c, "google")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) emailOAuthStart(c *gin.Context, provider string) {
|
||||||
|
cfg, err := h.getEmailOAuthConfig(c.Request.Context(), provider)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state, err := oauth.GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||||
|
if redirectTo == "" {
|
||||||
|
redirectTo = emailOAuthDefaultRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
secureCookie := isRequestHTTPS(c)
|
||||||
|
emailOAuthSetCookie(c, emailOAuthStateCookieName, encodeCookieValue(state), secureCookie)
|
||||||
|
emailOAuthSetCookie(c, emailOAuthRedirectCookie, encodeCookieValue(redirectTo), secureCookie)
|
||||||
|
emailOAuthSetCookie(c, emailOAuthProviderCookie, encodeCookieValue(provider), secureCookie)
|
||||||
|
if affCode := strings.TrimSpace(firstNonEmpty(c.Query("aff_code"), c.Query("aff"))); affCode != "" {
|
||||||
|
emailOAuthSetCookie(c, emailOAuthAffiliateCookie, encodeCookieValue(affCode), secureCookie)
|
||||||
|
} else {
|
||||||
|
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL, err := buildEmailOAuthAuthorizeURL(cfg, state)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Redirect(http.StatusFound, authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) emailOAuthCallback(c *gin.Context, provider string) {
|
||||||
|
cfg, cfgErr := h.getEmailOAuthConfig(c.Request.Context(), provider)
|
||||||
|
if cfgErr != nil {
|
||||||
|
response.ErrorFrom(c, cfgErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||||
|
if frontendCallback == "" {
|
||||||
|
frontendCallback = "/auth/oauth/callback"
|
||||||
|
}
|
||||||
|
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||||
|
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := strings.TrimSpace(c.Query("code"))
|
||||||
|
state := strings.TrimSpace(c.Query("state"))
|
||||||
|
if code == "" || state == "" {
|
||||||
|
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secureCookie := isRequestHTTPS(c)
|
||||||
|
defer func() {
|
||||||
|
emailOAuthClearCookie(c, emailOAuthStateCookieName, secureCookie)
|
||||||
|
emailOAuthClearCookie(c, emailOAuthRedirectCookie, secureCookie)
|
||||||
|
emailOAuthClearCookie(c, emailOAuthProviderCookie, secureCookie)
|
||||||
|
emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie)
|
||||||
|
}()
|
||||||
|
expectedState, err := readCookieDecoded(c, emailOAuthStateCookieName)
|
||||||
|
if err != nil || expectedState == "" || expectedState != state {
|
||||||
|
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expectedProvider, _ := readCookieDecoded(c, emailOAuthProviderCookie)
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(expectedProvider), provider) {
|
||||||
|
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth provider", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectTo, _ := readCookieDecoded(c, emailOAuthRedirectCookie)
|
||||||
|
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||||
|
if redirectTo == "" {
|
||||||
|
redirectTo = emailOAuthDefaultRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, err := exchangeEmailOAuthCode(c.Request.Context(), cfg, code)
|
||||||
|
if err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile, err := fetchEmailOAuthProfile(c.Request.Context(), provider, cfg, tokenResp)
|
||||||
|
if err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch verified email", singleLine(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.emailOAuthCallbackWithProfile(c, provider, cfg, frontendCallback, redirectTo, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) emailOAuthCallbackWithProfile(
|
||||||
|
c *gin.Context,
|
||||||
|
provider string,
|
||||||
|
cfg config.EmailOAuthProviderConfig,
|
||||||
|
frontendCallback string,
|
||||||
|
redirectTo string,
|
||||||
|
profile *emailOAuthProfile,
|
||||||
|
) {
|
||||||
|
input := service.EmailOAuthIdentityInput{
|
||||||
|
ProviderType: provider,
|
||||||
|
ProviderKey: provider,
|
||||||
|
ProviderSubject: profile.Subject,
|
||||||
|
Email: profile.Email,
|
||||||
|
EmailVerified: profile.EmailVerified,
|
||||||
|
Username: profile.Username,
|
||||||
|
DisplayName: profile.DisplayName,
|
||||||
|
AvatarURL: profile.AvatarURL,
|
||||||
|
UpstreamMetadata: profile.Metadata,
|
||||||
|
}
|
||||||
|
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||||
|
if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
|
||||||
|
return
|
||||||
|
} else if shouldCreate {
|
||||||
|
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectToFrontendCallback(c, frontendCallback)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||||
|
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectToFrontendCallback(c, frontendCallback)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment := url.Values{}
|
||||||
|
fragment.Set("access_token", tokenPair.AccessToken)
|
||||||
|
fragment.Set("refresh_token", tokenPair.RefreshToken)
|
||||||
|
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
|
||||||
|
fragment.Set("token_type", "Bearer")
|
||||||
|
fragment.Set("redirect", redirectTo)
|
||||||
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) {
|
||||||
|
client := h.entClient()
|
||||||
|
if client == nil {
|
||||||
|
return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||||
|
}
|
||||||
|
identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{
|
||||||
|
ProviderType: strings.TrimSpace(input.ProviderType),
|
||||||
|
ProviderKey: strings.TrimSpace(input.ProviderKey),
|
||||||
|
ProviderSubject: strings.TrimSpace(input.ProviderSubject),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
email := strings.TrimSpace(strings.ToLower(input.Email))
|
||||||
|
if identityUser != nil {
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
|
||||||
|
return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil {
|
||||||
|
if errors.Is(err, service.ErrUserNotFound) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if code, err := readCookieDecoded(c, emailOAuthAffiliateCookie); err == nil {
|
||||||
|
return strings.TrimSpace(code)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) createEmailOAuthRegistrationPendingSession(
|
||||||
|
c *gin.Context,
|
||||||
|
provider string,
|
||||||
|
frontendCallback string,
|
||||||
|
redirectTo string,
|
||||||
|
profile *emailOAuthProfile,
|
||||||
|
) error {
|
||||||
|
if h == nil || profile == nil {
|
||||||
|
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||||
|
}
|
||||||
|
browserSessionKey, err := generateOAuthPendingBrowserSession()
|
||||||
|
if err != nil {
|
||||||
|
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
|
||||||
|
}
|
||||||
|
setOAuthPendingBrowserCookie(c, browserSessionKey, isRequestHTTPS(c))
|
||||||
|
|
||||||
|
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||||
|
username := strings.TrimSpace(profile.Username)
|
||||||
|
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||||
|
upstreamClaims := map[string]any{
|
||||||
|
"email": email,
|
||||||
|
"email_verified": profile.EmailVerified,
|
||||||
|
"username": username,
|
||||||
|
"provider": provider,
|
||||||
|
"provider_key": provider,
|
||||||
|
"provider_subject": strings.TrimSpace(profile.Subject),
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(profile.DisplayName) != "" {
|
||||||
|
upstreamClaims["suggested_display_name"] = strings.TrimSpace(profile.DisplayName)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(profile.AvatarURL) != "" {
|
||||||
|
upstreamClaims["suggested_avatar_url"] = strings.TrimSpace(profile.AvatarURL)
|
||||||
|
}
|
||||||
|
if affiliateCode != "" {
|
||||||
|
upstreamClaims["aff_code"] = affiliateCode
|
||||||
|
}
|
||||||
|
for key, value := range profile.Metadata {
|
||||||
|
if _, exists := upstreamClaims[key]; !exists {
|
||||||
|
upstreamClaims[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context())
|
||||||
|
pendingError := "registration_completion_required"
|
||||||
|
choiceReason := "registration_completion_required"
|
||||||
|
if invitationRequired {
|
||||||
|
pendingError = "invitation_required"
|
||||||
|
choiceReason = "invitation_required"
|
||||||
|
}
|
||||||
|
completionResponse := map[string]any{
|
||||||
|
"step": oauthPendingChoiceStep,
|
||||||
|
"error": pendingError,
|
||||||
|
"choice_reason": choiceReason,
|
||||||
|
"adoption_required": false,
|
||||||
|
"create_account_allowed": true,
|
||||||
|
"existing_account_bindable": false,
|
||||||
|
"force_email_on_signup": true,
|
||||||
|
"invitation_required": invitationRequired,
|
||||||
|
"email": email,
|
||||||
|
"resolved_email": email,
|
||||||
|
"provider": provider,
|
||||||
|
"redirect": redirectTo,
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(frontendCallback) != "" {
|
||||||
|
completionResponse["frontend_callback"] = strings.TrimSpace(frontendCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||||
|
Intent: oauthIntentLogin,
|
||||||
|
Identity: service.PendingAuthIdentityKey{ProviderType: provider, ProviderKey: provider, ProviderSubject: strings.TrimSpace(profile.Subject)},
|
||||||
|
ResolvedEmail: email,
|
||||||
|
RedirectTo: redirectTo,
|
||||||
|
BrowserSessionKey: browserSessionKey,
|
||||||
|
UpstreamIdentityClaims: upstreamClaims,
|
||||||
|
CompletionResponse: completionResponse,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type completeEmailOAuthRequest struct {
|
||||||
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
|
InvitationCode string `json:"invitation_code,omitempty"`
|
||||||
|
AffCode string `json:"aff_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider string) {
|
||||||
|
var req completeEmailOAuthRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
|
||||||
|
response.BadRequest(c, "Pending oauth session provider mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
affiliateCode := strings.TrimSpace(req.AffCode)
|
||||||
|
if affiliateCode == "" {
|
||||||
|
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
|
||||||
|
c.Request.Context(),
|
||||||
|
strings.TrimSpace(session.ResolvedEmail),
|
||||||
|
req.Password,
|
||||||
|
strings.TrimSpace(req.InvitationCode),
|
||||||
|
strings.TrimSpace(session.ProviderType),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := h.entClient()
|
||||||
|
if client == nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx, err := client.Tx(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
|
||||||
|
sessionForBinding := *session
|
||||||
|
sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims)
|
||||||
|
if strings.TrimSpace(req.InvitationCode) != "" {
|
||||||
|
sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode)
|
||||||
|
}
|
||||||
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{})
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
respondPendingOAuthBindingApplyError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.authService.FinalizeOAuthEmailAccount(
|
||||||
|
txCtx,
|
||||||
|
user,
|
||||||
|
strings.TrimSpace(req.InvitationCode),
|
||||||
|
strings.TrimSpace(session.ProviderType),
|
||||||
|
affiliateCode,
|
||||||
|
); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
clearCookies()
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||||
|
clearCookies()
|
||||||
|
writeOAuthTokenPairResponse(c, tokenPair)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) getEmailOAuthConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
|
||||||
|
if h != nil && h.settingSvc != nil {
|
||||||
|
return h.settingSvc.GetEmailOAuthProviderConfig(ctx, provider)
|
||||||
|
}
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildEmailOAuthAuthorizeURL(cfg config.EmailOAuthProviderConfig, state string) (string, error) {
|
||||||
|
u, err := url.Parse(cfg.AuthorizeURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("response_type", "code")
|
||||||
|
q.Set("client_id", cfg.ClientID)
|
||||||
|
q.Set("redirect_uri", cfg.RedirectURL)
|
||||||
|
q.Set("state", state)
|
||||||
|
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||||
|
q.Set("scope", cfg.Scopes)
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
return u.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeEmailOAuthCode(ctx context.Context, cfg config.EmailOAuthProviderConfig, code string) (*emailOAuthTokenResponse, error) {
|
||||||
|
resp, err := req.C().
|
||||||
|
R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
SetFormData(map[string]string{
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": cfg.ClientID,
|
||||||
|
"client_secret": cfg.ClientSecret,
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": cfg.RedirectURL,
|
||||||
|
}).
|
||||||
|
Post(cfg.TokenURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("token endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||||
|
}
|
||||||
|
var tokenResp emailOAuthTokenResponse
|
||||||
|
if err := json.Unmarshal(resp.Bytes(), &tokenResp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||||
|
return nil, errors.New("missing access_token")
|
||||||
|
}
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchEmailOAuthProfile(ctx context.Context, provider string, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse) (*emailOAuthProfile, error) {
|
||||||
|
resp, err := req.C().
|
||||||
|
R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetBearerAuthToken(token.AccessToken).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
Get(cfg.UserInfoURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("userinfo endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||||
|
case "github":
|
||||||
|
return parseGitHubOAuthProfile(ctx, cfg, token, resp.String())
|
||||||
|
case "google":
|
||||||
|
return parseGoogleOAuthProfile(resp.String())
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported oauth provider")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse, body string) (*emailOAuthProfile, error) {
|
||||||
|
subject := strings.TrimSpace(gjson.Get(body, "id").String())
|
||||||
|
if subject == "" {
|
||||||
|
return nil, errors.New("github user id is missing")
|
||||||
|
}
|
||||||
|
email := ""
|
||||||
|
emailsURL := strings.TrimSpace(cfg.EmailsURL)
|
||||||
|
if emailsURL == "" {
|
||||||
|
return nil, errors.New("github verified email is missing")
|
||||||
|
}
|
||||||
|
verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
email = verifiedEmail
|
||||||
|
if email == "" {
|
||||||
|
return nil, errors.New("github verified email is missing")
|
||||||
|
}
|
||||||
|
login := strings.TrimSpace(gjson.Get(body, "login").String())
|
||||||
|
name := strings.TrimSpace(gjson.Get(body, "name").String())
|
||||||
|
return &emailOAuthProfile{
|
||||||
|
Subject: subject,
|
||||||
|
Email: email,
|
||||||
|
EmailVerified: true,
|
||||||
|
Username: firstNonEmpty(login, name, "github_"+subject),
|
||||||
|
DisplayName: firstNonEmpty(name, login),
|
||||||
|
AvatarURL: strings.TrimSpace(gjson.Get(body, "avatar_url").String()),
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"login": login,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchGitHubPrimaryVerifiedEmail(ctx context.Context, emailsURL string, accessToken string) (string, error) {
|
||||||
|
resp, err := req.C().
|
||||||
|
R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetBearerAuthToken(accessToken).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
Get(emailsURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return "", fmt.Errorf("github emails endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024))
|
||||||
|
}
|
||||||
|
items := gjson.Parse(resp.String()).Array()
|
||||||
|
for _, item := range items {
|
||||||
|
if item.Get("primary").Bool() && item.Get("verified").Bool() {
|
||||||
|
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
|
||||||
|
return email, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
if item.Get("verified").Bool() {
|
||||||
|
if email := strings.TrimSpace(item.Get("email").String()); email != "" {
|
||||||
|
return email, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", errors.New("github verified email is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGoogleOAuthProfile(body string) (*emailOAuthProfile, error) {
|
||||||
|
subject := strings.TrimSpace(gjson.Get(body, "sub").String())
|
||||||
|
email := strings.TrimSpace(gjson.Get(body, "email").String())
|
||||||
|
verified := gjson.Get(body, "email_verified").Bool()
|
||||||
|
if subject == "" {
|
||||||
|
return nil, errors.New("google subject is missing")
|
||||||
|
}
|
||||||
|
if email == "" || !verified {
|
||||||
|
return nil, errors.New("google verified email is missing")
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(gjson.Get(body, "name").String())
|
||||||
|
return &emailOAuthProfile{
|
||||||
|
Subject: subject,
|
||||||
|
Email: email,
|
||||||
|
EmailVerified: true,
|
||||||
|
Username: firstNonEmpty(strings.TrimSpace(gjson.Get(body, "given_name").String()), name, email),
|
||||||
|
DisplayName: name,
|
||||||
|
AvatarURL: strings.TrimSpace(gjson.Get(body, "picture").String()),
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"email_verified": true,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailOAuthSetCookie(c *gin.Context, name, value string, secure bool) {
|
||||||
|
http.SetCookie(c.Writer, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: emailOAuthCookiePath,
|
||||||
|
MaxAge: emailOAuthCookieMaxAgeSec,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailOAuthClearCookie(c *gin.Context, name string, secure bool) {
|
||||||
|
http.SetCookie(c.Writer, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: "",
|
||||||
|
Path: emailOAuthCookiePath,
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
414
backend/internal/handler/auth_email_oauth_test.go
Normal file
414
backend/internal/handler/auth_email_oauth_test.go
Normal file
@ -0,0 +1,414 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
state := "github-oauth-state"
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")})
|
||||||
|
req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
profile := &emailOAuthProfile{
|
||||||
|
Subject: "github-123",
|
||||||
|
Email: "fresh@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Username: "fresh",
|
||||||
|
DisplayName: "Fresh User",
|
||||||
|
AvatarURL: "https://cdn.example/fresh.png",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"login": "fresh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
|
||||||
|
Enabled: true,
|
||||||
|
ClientID: "github-client",
|
||||||
|
ClientSecret: "github-secret",
|
||||||
|
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
|
||||||
|
FrontendRedirectURL: "/auth/oauth/callback",
|
||||||
|
}, "/auth/oauth/callback", "/dashboard", profile)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
require.Contains(t, location, "/auth/oauth/callback")
|
||||||
|
require.NotContains(t, location, "access_token=")
|
||||||
|
|
||||||
|
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Query().Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "github", session.ProviderType)
|
||||||
|
require.Equal(t, "github", session.ProviderKey)
|
||||||
|
require.Equal(t, "github-123", session.ProviderSubject)
|
||||||
|
require.Equal(t, "fresh@example.com", session.ResolvedEmail)
|
||||||
|
require.Equal(t, "/dashboard", session.RedirectTo)
|
||||||
|
require.Nil(t, session.TargetUserID)
|
||||||
|
|
||||||
|
completion, ok := readCompletionResponse(session.LocalFlowState)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||||
|
require.Equal(t, "invitation_required", completion["error"])
|
||||||
|
require.Equal(t, true, completion["invitation_required"])
|
||||||
|
require.Equal(t, "fresh@example.com", completion["email"])
|
||||||
|
require.Equal(t, "fresh@example.com", completion["resolved_email"])
|
||||||
|
require.Equal(t, true, completion["create_account_allowed"])
|
||||||
|
|
||||||
|
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||||
|
require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("existing@example.com").
|
||||||
|
SetUsername("existing").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil)
|
||||||
|
|
||||||
|
handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{
|
||||||
|
Enabled: true,
|
||||||
|
ClientID: "google-client",
|
||||||
|
ClientSecret: "google-secret",
|
||||||
|
RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback",
|
||||||
|
FrontendRedirectURL: "/auth/oauth/callback",
|
||||||
|
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
|
||||||
|
Subject: "google-123",
|
||||||
|
Email: "existing@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Username: "existing",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
require.Contains(t, location, "access_token=")
|
||||||
|
require.Contains(t, location, "redirect=%252Fdashboard")
|
||||||
|
|
||||||
|
sessionCount, err := client.PendingAuthSession.Query().Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, sessionCount)
|
||||||
|
|
||||||
|
identityCount, err := client.AuthIdentity.Query().Where(
|
||||||
|
authidentity.ProviderTypeEQ("google"),
|
||||||
|
authidentity.ProviderSubjectEQ("google-123"),
|
||||||
|
).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, identityCount)
|
||||||
|
_ = user
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) {
|
||||||
|
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
settingValues: map[string]string{
|
||||||
|
service.SettingKeyAffiliateEnabled: "true",
|
||||||
|
},
|
||||||
|
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
|
||||||
|
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{
|
||||||
|
Enabled: true,
|
||||||
|
ClientID: "github-client",
|
||||||
|
ClientSecret: "github-secret",
|
||||||
|
RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback",
|
||||||
|
FrontendRedirectURL: "/auth/oauth/callback",
|
||||||
|
}, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{
|
||||||
|
Subject: "github-aff-user",
|
||||||
|
Email: "aff-user@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Username: "aff-user",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
require.NotContains(t, recorder.Header().Get("Location"), "access_token=")
|
||||||
|
userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
require.Empty(t, affiliateRepo.ensureUserIDs)
|
||||||
|
require.Empty(t, affiliateRepo.bindCalls)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Query().Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "aff-user@example.com", session.ResolvedEmail)
|
||||||
|
require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code"))
|
||||||
|
|
||||||
|
completion, ok := readCompletionResponse(session.LocalFlowState)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||||
|
require.Equal(t, "registration_completion_required", completion["error"])
|
||||||
|
require.Equal(t, false, completion["invitation_required"])
|
||||||
|
require.Equal(t, true, completion["create_account_allowed"])
|
||||||
|
require.Equal(t, true, completion["force_email_on_signup"])
|
||||||
|
require.Equal(t, "aff-user@example.com", completion["resolved_email"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
|
||||||
|
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002})
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
invitationEnabled: true,
|
||||||
|
settingValues: map[string]string{
|
||||||
|
service.SettingKeyAffiliateEnabled: "true",
|
||||||
|
},
|
||||||
|
affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService {
|
||||||
|
return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
invitation, err := client.RedeemCode.Create().
|
||||||
|
SetCode("INVITE456").
|
||||||
|
SetType(service.RedeemTypeInvitation).
|
||||||
|
SetStatus(service.StatusUnused).
|
||||||
|
SetValue(0).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("email-oauth-aff-session-token").
|
||||||
|
SetIntent(oauthIntentLogin).
|
||||||
|
SetProviderType("google").
|
||||||
|
SetProviderKey("google").
|
||||||
|
SetProviderSubject("google-aff-user").
|
||||||
|
SetResolvedEmail("pending-aff@example.com").
|
||||||
|
SetRedirectTo("/dashboard").
|
||||||
|
SetBrowserSessionKey("browser-aff-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"email": "pending-aff@example.com",
|
||||||
|
"email_verified": true,
|
||||||
|
"username": "pending-aff",
|
||||||
|
"provider": "google",
|
||||||
|
"provider_key": "google",
|
||||||
|
"provider_subject": "google-aff-user",
|
||||||
|
"aff_code": "AFF456",
|
||||||
|
}).
|
||||||
|
SetLocalFlowState(map[string]any{
|
||||||
|
"step": oauthPendingChoiceStep,
|
||||||
|
"error": "invitation_required",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.completeEmailOAuthRegistration(c, "google")
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, user.PasswordHash)
|
||||||
|
require.NotEqual(t, "secret-123", user.PasswordHash)
|
||||||
|
tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, tamperedCount)
|
||||||
|
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
|
||||||
|
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, storedInvitation.UsedBy)
|
||||||
|
require.Equal(t, user.ID, *storedInvitation.UsedBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("email-oauth-password-session-token").
|
||||||
|
SetIntent(oauthIntentLogin).
|
||||||
|
SetProviderType("github").
|
||||||
|
SetProviderKey("github").
|
||||||
|
SetProviderSubject("github-password-user").
|
||||||
|
SetResolvedEmail("password-required@example.com").
|
||||||
|
SetRedirectTo("/dashboard").
|
||||||
|
SetBrowserSessionKey("browser-password-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"email": "password-required@example.com",
|
||||||
|
"email_verified": true,
|
||||||
|
"username": "password-required",
|
||||||
|
"provider": "github",
|
||||||
|
"provider_key": "github",
|
||||||
|
"provider_subject": "github-password-user",
|
||||||
|
}).
|
||||||
|
SetLocalFlowState(map[string]any{
|
||||||
|
"step": oauthPendingChoiceStep,
|
||||||
|
"error": "registration_completion_required",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.completeEmailOAuthRegistration(c, "github")
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) {
|
||||||
|
emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
http.Error(w, "missing scope", http.StatusForbidden)
|
||||||
|
}))
|
||||||
|
t.Cleanup(emailServer.Close)
|
||||||
|
|
||||||
|
profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{
|
||||||
|
EmailsURL: emailServer.URL,
|
||||||
|
}, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, profile)
|
||||||
|
require.Contains(t, err.Error(), "github emails endpoint status 403")
|
||||||
|
}
|
||||||
|
|
||||||
|
type oauthEmailAffiliateBindCall struct {
|
||||||
|
userID int64
|
||||||
|
inviterID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type oauthEmailAffiliateRepoStub struct {
|
||||||
|
codeOwners map[string]int64
|
||||||
|
ensureUserIDs []int64
|
||||||
|
bindCalls []oauthEmailAffiliateBindCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub {
|
||||||
|
return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) {
|
||||||
|
r.ensureUserIDs = append(r.ensureUserIDs, userID)
|
||||||
|
return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) {
|
||||||
|
userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))]
|
||||||
|
if !ok {
|
||||||
|
return nil, service.ErrAffiliateProfileNotFound
|
||||||
|
}
|
||||||
|
return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) {
|
||||||
|
r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID})
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) {
|
||||||
|
panic("unexpected AccrueQuota call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) {
|
||||||
|
panic("unexpected GetAccruedRebateFromInvitee call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) {
|
||||||
|
panic("unexpected ThawFrozenQuota call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) {
|
||||||
|
panic("unexpected TransferQuotaToBalance call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) {
|
||||||
|
panic("unexpected ListInvitees call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error {
|
||||||
|
panic("unexpected UpdateUserAffCode call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) {
|
||||||
|
panic("unexpected ResetUserAffCode call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error {
|
||||||
|
panic("unexpected SetUserRebateRate call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error {
|
||||||
|
panic("unexpected BatchSetUserRebateRate call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
|
||||||
|
panic("unexpected ListUsersWithCustomSettings call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
|
||||||
|
panic("unexpected ListAffiliateInviteRecords call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
|
||||||
|
panic("unexpected ListAffiliateRebateRecords call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
|
||||||
|
panic("unexpected ListAffiliateTransferRecords call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) {
|
||||||
|
panic("unexpected GetAffiliateUserOverview call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func findSetCookieValue(cookies []*http.Cookie, name string) string {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 {
|
||||||
|
return cookie.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@ -2121,6 +2121,8 @@ type oauthPendingFlowTestHandlerOptions struct {
|
|||||||
emailCache service.EmailCache
|
emailCache service.EmailCache
|
||||||
settingValues map[string]string
|
settingValues map[string]string
|
||||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||||
|
affiliateService *service.AffiliateService
|
||||||
|
affiliateFactory func(*dbent.Client, *service.SettingService) *service.AffiliateService
|
||||||
totpCache service.TotpCache
|
totpCache service.TotpCache
|
||||||
totpEncryptor service.SecretEncryptor
|
totpEncryptor service.SecretEncryptor
|
||||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||||
@ -2160,6 +2162,21 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
|||||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
)`)
|
)`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
_, err = db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS user_affiliates (
|
||||||
|
user_id INTEGER PRIMARY KEY,
|
||||||
|
aff_code TEXT NOT NULL UNIQUE,
|
||||||
|
aff_code_custom BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
aff_rebate_rate_percent REAL NULL,
|
||||||
|
inviter_id INTEGER NULL,
|
||||||
|
aff_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
aff_quota REAL NOT NULL DEFAULT 0,
|
||||||
|
aff_frozen_quota REAL NOT NULL DEFAULT 0,
|
||||||
|
aff_history_quota REAL NOT NULL DEFAULT 0,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
@ -2177,14 +2194,19 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
settingValues := map[string]string{
|
settingValues := map[string]string{
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
||||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||||
|
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||||
}
|
}
|
||||||
for key, value := range options.settingValues {
|
for key, value := range options.settingValues {
|
||||||
settingValues[key] = value
|
settingValues[key] = value
|
||||||
}
|
}
|
||||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
||||||
|
affiliateService := options.affiliateService
|
||||||
|
if affiliateService == nil && options.affiliateFactory != nil {
|
||||||
|
affiliateService = options.affiliateFactory(client, settingSvc)
|
||||||
|
}
|
||||||
userRepo := &oauthPendingFlowUserRepo{
|
userRepo := &oauthPendingFlowUserRepo{
|
||||||
client: client,
|
client: client,
|
||||||
options: options.userRepoOptions,
|
options: options.userRepoOptions,
|
||||||
@ -2210,7 +2232,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
options.defaultSubAssigner,
|
options.defaultSubAssigner,
|
||||||
nil,
|
affiliateService,
|
||||||
)
|
)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
var totpSvc *service.TotpService
|
var totpSvc *service.TotpService
|
||||||
@ -2798,6 +2820,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
|
|||||||
panic("unexpected UpdateConcurrency call")
|
panic("unexpected UpdateConcurrency call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *oauthPendingFlowUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
|
||||||
|
panic("unexpected BatchSetConcurrency call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oauthPendingFlowUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
|
||||||
|
panic("unexpected BatchAddConcurrency call")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||||
return map[int64]*time.Time{}, nil
|
return map[int64]*time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
130
backend/internal/handler/content_moderation_helper.go
Normal file
130
backend/internal/handler/content_moderation_helper.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||||
|
if h == nil || h.contentModerationService == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentModerationStatus(decision *service.ContentModerationDecision) int {
|
||||||
|
if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 {
|
||||||
|
return http.StatusForbidden
|
||||||
|
}
|
||||||
|
return decision.StatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentModerationErrorCode(decision *service.ContentModerationDecision) string {
|
||||||
|
return "content_policy_violation"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||||
|
if h == nil || h.contentModerationService == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||||
|
if svc == nil || c == nil || c.Request == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
input := buildContentModerationInput(c, apiKey, subject, protocol, model, body)
|
||||||
|
if reqLog != nil {
|
||||||
|
reqLog.Info("content_moderation.gateway_check_start",
|
||||||
|
zap.String("request_id", input.RequestID),
|
||||||
|
zap.Int64("user_id", input.UserID),
|
||||||
|
zap.Int64("api_key_id", input.APIKeyID),
|
||||||
|
zap.String("api_key_name", input.APIKeyName),
|
||||||
|
zap.Int64p("group_id", input.GroupID),
|
||||||
|
zap.String("group_name", input.GroupName),
|
||||||
|
zap.String("endpoint", input.Endpoint),
|
||||||
|
zap.String("provider", input.Provider),
|
||||||
|
zap.String("protocol", input.Protocol),
|
||||||
|
zap.String("model", input.Model),
|
||||||
|
zap.Int("body_bytes", len(body)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
decision, err := svc.Check(c.Request.Context(), input)
|
||||||
|
if err != nil {
|
||||||
|
if reqLog != nil {
|
||||||
|
reqLog.Warn("content_moderation.check_failed", zap.Error(err))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if reqLog != nil && decision != nil {
|
||||||
|
reqLog.Info("content_moderation.gateway_check_done",
|
||||||
|
zap.String("request_id", input.RequestID),
|
||||||
|
zap.Bool("allowed", decision.Allowed),
|
||||||
|
zap.Bool("blocked", decision.Blocked),
|
||||||
|
zap.Bool("flagged", decision.Flagged),
|
||||||
|
zap.String("action", decision.Action),
|
||||||
|
zap.Int("status_code", decision.StatusCode),
|
||||||
|
zap.String("highest_category", decision.HighestCategory),
|
||||||
|
zap.Float64("highest_score", decision.HighestScore),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return decision
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput {
|
||||||
|
input := service.ContentModerationCheckInput{
|
||||||
|
RequestID: contentModerationRequestID(c.Request.Context()),
|
||||||
|
UserID: subject.UserID,
|
||||||
|
Endpoint: GetInboundEndpoint(c),
|
||||||
|
Provider: contentModerationProvider(apiKey),
|
||||||
|
Model: strings.TrimSpace(model),
|
||||||
|
Protocol: protocol,
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||||
|
input.Provider = strings.TrimSpace(forcedPlatform)
|
||||||
|
}
|
||||||
|
if apiKey != nil {
|
||||||
|
input.APIKeyID = apiKey.ID
|
||||||
|
input.APIKeyName = apiKey.Name
|
||||||
|
if apiKey.User != nil {
|
||||||
|
input.UserEmail = apiKey.User.Email
|
||||||
|
}
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
groupID := *apiKey.GroupID
|
||||||
|
input.GroupID = &groupID
|
||||||
|
}
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
input.GroupName = apiKey.Group.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil {
|
||||||
|
input.Endpoint = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentModerationProvider(apiKey *service.APIKey) string {
|
||||||
|
if apiKey == nil || apiKey.Group == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(apiKey.Group.Platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentModerationRequestID(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok {
|
||||||
|
return strings.TrimSpace(requestID)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@ -11,6 +11,7 @@ type CustomMenuItem struct {
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
IconSVG string `json:"icon_svg"`
|
IconSVG string `json:"icon_svg"`
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
|
PageSlug string `json:"page_slug,omitempty"`
|
||||||
Visibility string `json:"visibility"` // "user" or "admin"
|
Visibility string `json:"visibility"` // "user" or "admin"
|
||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
}
|
}
|
||||||
@ -24,15 +25,19 @@ type CustomEndpoint struct {
|
|||||||
|
|
||||||
// SystemSettings represents the admin settings API response payload.
|
// SystemSettings represents the admin settings API response payload.
|
||||||
type SystemSettings struct {
|
type SystemSettings struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
FrontendURL string `json:"frontend_url"`
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
|
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||||
|
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||||
|
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||||
|
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
|
||||||
|
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
SMTPPort int `json:"smtp_port"`
|
SMTPPort int `json:"smtp_port"`
|
||||||
@ -91,6 +96,17 @@ type SystemSettings struct {
|
|||||||
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
|
||||||
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
|
||||||
|
|
||||||
|
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||||
|
GitHubOAuthClientID string `json:"github_oauth_client_id"`
|
||||||
|
GitHubOAuthClientSecretConfigured bool `json:"github_oauth_client_secret_configured"`
|
||||||
|
GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"`
|
||||||
|
GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"`
|
||||||
|
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||||
|
GoogleOAuthClientID string `json:"google_oauth_client_id"`
|
||||||
|
GoogleOAuthClientSecretConfigured bool `json:"google_oauth_client_secret_configured"`
|
||||||
|
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"`
|
||||||
|
GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"`
|
||||||
|
|
||||||
SiteName string `json:"site_name"`
|
SiteName string `json:"site_name"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteLogo string `json:"site_logo"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
@ -197,6 +213,9 @@ type SystemSettings struct {
|
|||||||
// Available Channels feature switch (user-facing aggregate view)
|
// Available Channels feature switch (user-facing aggregate view)
|
||||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||||
|
|
||||||
|
// 风控中心功能开关
|
||||||
|
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||||
|
|
||||||
// Affiliate (邀请返利) feature switch
|
// Affiliate (邀请返利) feature switch
|
||||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||||
|
|
||||||
@ -210,45 +229,52 @@ type DefaultSubscriptionSetting struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PublicSettings struct {
|
type PublicSettings struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||||
SiteName string `json:"site_name"`
|
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||||
SiteLogo string `json:"site_logo"`
|
LoginAgreementRevision string `json:"login_agreement_revision"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
|
||||||
APIBaseURL string `json:"api_base_url"`
|
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||||
ContactInfo string `json:"contact_info"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
DocURL string `json:"doc_url"`
|
SiteName string `json:"site_name"`
|
||||||
HomeContent string `json:"home_content"`
|
SiteLogo string `json:"site_logo"`
|
||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
ContactInfo string `json:"contact_info"`
|
||||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
DocURL string `json:"doc_url"`
|
||||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
HomeContent string `json:"home_content"`
|
||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||||
PaymentEnabled bool `json:"payment_enabled"`
|
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||||
Version string `json:"version"`
|
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
|
PaymentEnabled bool `json:"payment_enabled"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||||
|
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||||
|
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||||
|
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||||
|
|
||||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||||
@ -256,6 +282,14 @@ type PublicSettings struct {
|
|||||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||||
|
|
||||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||||
|
|
||||||
|
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginAgreementDocument struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
ContentMD string `json:"content_md"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||||
|
|||||||
@ -46,6 +46,7 @@ type GatewayHandler struct {
|
|||||||
apiKeyService *service.APIKeyService
|
apiKeyService *service.APIKeyService
|
||||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||||
errorPassthroughService *service.ErrorPassthroughService
|
errorPassthroughService *service.ErrorPassthroughService
|
||||||
|
contentModerationService *service.ContentModerationService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
userMsgQueueHelper *UserMsgQueueHelper
|
userMsgQueueHelper *UserMsgQueueHelper
|
||||||
requestEventBus *service.RequestEventBus
|
requestEventBus *service.RequestEventBus
|
||||||
@ -68,6 +69,7 @@ func NewGatewayHandler(
|
|||||||
apiKeyService *service.APIKeyService,
|
apiKeyService *service.APIKeyService,
|
||||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||||
errorPassthroughService *service.ErrorPassthroughService,
|
errorPassthroughService *service.ErrorPassthroughService,
|
||||||
|
contentModerationService *service.ContentModerationService,
|
||||||
userMsgQueueService *service.UserMessageQueueService,
|
userMsgQueueService *service.UserMessageQueueService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
@ -103,6 +105,7 @@ func NewGatewayHandler(
|
|||||||
apiKeyService: apiKeyService,
|
apiKeyService: apiKeyService,
|
||||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||||
errorPassthroughService: errorPassthroughService,
|
errorPassthroughService: errorPassthroughService,
|
||||||
|
contentModerationService: contentModerationService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||||
userMsgQueueHelper: umqHelper,
|
userMsgQueueHelper: umqHelper,
|
||||||
requestEventBus: requestEventBus,
|
requestEventBus: requestEventBus,
|
||||||
@ -215,6 +218,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||||
|
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Track if we've started streaming (for error handling)
|
// Track if we've started streaming (for error handling)
|
||||||
streamStarted := false
|
streamStarted := false
|
||||||
|
|
||||||
|
|||||||
@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||||
|
h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Error passthrough binding
|
// Error passthrough binding
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
|
|||||||
@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||||
|
h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Error passthrough binding
|
// Error passthrough binding
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
|
|||||||
@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, modelName, stream, body)
|
setOpsRequestContext(c, modelName, stream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
|
||||||
|
googleError(c, contentModerationStatus(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||||
reqModel := modelName // 保存映射前的原始模型名
|
reqModel := modelName // 保存映射前的原始模型名
|
||||||
|
|||||||
@ -33,6 +33,7 @@ type AdminHandlers struct {
|
|||||||
Channel *admin.ChannelHandler
|
Channel *admin.ChannelHandler
|
||||||
ChannelMonitor *admin.ChannelMonitorHandler
|
ChannelMonitor *admin.ChannelMonitorHandler
|
||||||
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
|
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
|
||||||
|
ContentModeration *admin.ContentModerationHandler
|
||||||
Payment *admin.PaymentHandler
|
Payment *admin.PaymentHandler
|
||||||
Windsurf *admin.WindsurfHandler
|
Windsurf *admin.WindsurfHandler
|
||||||
Affiliate *admin.AffiliateHandler
|
Affiliate *admin.AffiliateHandler
|
||||||
|
|||||||
@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||||
|
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
|
|||||||
@ -27,15 +27,16 @@ import (
|
|||||||
|
|
||||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||||
type OpenAIGatewayHandler struct {
|
type OpenAIGatewayHandler struct {
|
||||||
gatewayService *service.OpenAIGatewayService
|
gatewayService *service.OpenAIGatewayService
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
apiKeyService *service.APIKeyService
|
apiKeyService *service.APIKeyService
|
||||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||||
errorPassthroughService *service.ErrorPassthroughService
|
errorPassthroughService *service.ErrorPassthroughService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
contentModerationService *service.ContentModerationService
|
||||||
imageLimiter *imageConcurrencyLimiter
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
imageLimiter *imageConcurrencyLimiter
|
||||||
cfg *config.Config
|
maxAccountSwitches int
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
||||||
@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler(
|
|||||||
apiKeyService *service.APIKeyService,
|
apiKeyService *service.APIKeyService,
|
||||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||||
errorPassthroughService *service.ErrorPassthroughService,
|
errorPassthroughService *service.ErrorPassthroughService,
|
||||||
|
contentModerationService *service.ContentModerationService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *OpenAIGatewayHandler {
|
) *OpenAIGatewayHandler {
|
||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &OpenAIGatewayHandler{
|
return &OpenAIGatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
apiKeyService: apiKeyService,
|
apiKeyService: apiKeyService,
|
||||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||||
errorPassthroughService: errorPassthroughService,
|
errorPassthroughService: errorPassthroughService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
contentModerationService: contentModerationService,
|
||||||
imageLimiter: &imageConcurrencyLimiter{},
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
imageLimiter: &imageConcurrencyLimiter{},
|
||||||
cfg: cfg,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||||
|
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||||
@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||||
|
h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射
|
||||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||||
|
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
|
||||||
|
writeContentModerationWSError(ctx, wsConn, decision)
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||||
return
|
return
|
||||||
@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
|
|
||||||
hooks := &service.OpenAIWSIngressHooks{
|
hooks := &service.OpenAIWSIngressHooks{
|
||||||
InitialRequestModel: reqModel,
|
InitialRequestModel: reqModel,
|
||||||
|
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||||
|
if turn == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(payload) {
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||||
|
}
|
||||||
|
model := strings.TrimSpace(originalModel)
|
||||||
|
if model == "" {
|
||||||
|
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
model = reqModel
|
||||||
|
}
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||||
|
writeContentModerationWSError(ctx, wsConn, decision)
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
BeforeTurn: func(turn int) error {
|
BeforeTurn: func(turn int) error {
|
||||||
if turn == 1 {
|
if turn == 1 {
|
||||||
return nil
|
return nil
|
||||||
@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
|||||||
_ = conn.CloseNow()
|
_ = conn.CloseNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||||
|
if conn == nil || decision == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
message := strings.TrimSpace(decision.Message)
|
||||||
|
if message == "" {
|
||||||
|
message = "content moderation blocked this request"
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(gin.H{
|
||||||
|
"event_id": "evt_content_moderation_blocked",
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"code": contentModerationErrorCode(decision),
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`)
|
||||||
|
}
|
||||||
|
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = conn.Write(writeCtx, coderws.MessageText, payload)
|
||||||
|
}
|
||||||
|
|
||||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return "-", "-"
|
return "-", "-"
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
coderws "github.com/coder/websocket"
|
coderws "github.com/coder/websocket"
|
||||||
@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
|||||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type contentModerationHandlerSettingRepo struct {
|
||||||
|
values map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||||
|
if value, ok := r.values[key]; ok {
|
||||||
|
return &service.Setting{Key: key, Value: value}, nil
|
||||||
|
}
|
||||||
|
return nil, service.ErrSettingNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
if value, ok := r.values[key]; ok {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
return "", service.ErrSettingNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||||
|
if r.values == nil {
|
||||||
|
r.values = map[string]string{}
|
||||||
|
}
|
||||||
|
r.values[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
out := map[string]string{}
|
||||||
|
for _, key := range keys {
|
||||||
|
if value, ok := r.values[key]; ok {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
if r.values == nil {
|
||||||
|
r.values = map[string]string{}
|
||||||
|
}
|
||||||
|
for key, value := range settings {
|
||||||
|
r.values[key] = value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
out := make(map[string]string, len(r.values))
|
||||||
|
for key, value := range r.values {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
|
||||||
|
delete(r.values, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentModerationHandlerTestRepo struct {
|
||||||
|
logs []service.ContentModerationLog
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||||
|
if log != nil {
|
||||||
|
r.logs = append(r.logs, *log)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||||
|
return &service.ContentModerationCleanupResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "/v1/moderations", r.URL.Path)
|
||||||
|
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
|
||||||
|
}))
|
||||||
|
defer moderationServer.Close()
|
||||||
|
|
||||||
|
cfg := &service.ContentModerationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Mode: service.ContentModerationModePreBlock,
|
||||||
|
BaseURL: moderationServer.URL,
|
||||||
|
Model: "omni-moderation-latest",
|
||||||
|
APIKeys: []string{"sk-test"},
|
||||||
|
SampleRate: 100,
|
||||||
|
AllGroups: true,
|
||||||
|
BlockMessage: "内容审计测试阻断",
|
||||||
|
}
|
||||||
|
rawCfg, err := json.Marshal(cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
repo := &contentModerationHandlerTestRepo{}
|
||||||
|
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
|
||||||
|
service.SettingKeyRiskControlEnabled: "true",
|
||||||
|
service.SettingKeyContentModerationConfig: string(rawCfg),
|
||||||
|
}}
|
||||||
|
moderationSvc := service.NewContentModerationService(
|
||||||
|
settingRepo,
|
||||||
|
repo,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
|
||||||
|
UserID: 1,
|
||||||
|
Endpoint: "/v1/responses",
|
||||||
|
Provider: "openai",
|
||||||
|
Model: "gpt-5.5",
|
||||||
|
Protocol: service.ContentModerationProtocolOpenAIResponses,
|
||||||
|
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, decision.Blocked)
|
||||||
|
repo.logs = nil
|
||||||
|
h := &OpenAIGatewayHandler{
|
||||||
|
gatewayService: &service.OpenAIGatewayService{},
|
||||||
|
billingCacheService: &service.BillingCacheService{},
|
||||||
|
apiKeyService: &service.APIKeyService{},
|
||||||
|
contentModerationService: moderationSvc,
|
||||||
|
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
|
||||||
|
}
|
||||||
|
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
|
||||||
|
"type":"response.create",
|
||||||
|
"model":"gpt-5.5",
|
||||||
|
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
|
||||||
|
}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, payload, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
if readErr == nil {
|
||||||
|
require.Contains(t, string(payload), "content_policy_violation")
|
||||||
|
require.Contains(t, string(payload), "内容审计测试阻断")
|
||||||
|
} else {
|
||||||
|
var closeErr coderws.CloseError
|
||||||
|
require.ErrorAs(t, readErr, &closeErr)
|
||||||
|
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||||
|
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
||||||
|
}
|
||||||
|
require.Len(t, repo.logs, 1)
|
||||||
|
require.True(t, repo.logs[0].Flagged)
|
||||||
|
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
|
||||||
|
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||||
|
|||||||
@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
|||||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked {
|
||||||
|
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
||||||
if !acquired {
|
if !acquired {
|
||||||
return
|
return
|
||||||
|
|||||||
283
backend/internal/handler/page_handler.go
Normal file
283
backend/internal/handler/page_handler.go
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var validSlugPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
|
||||||
|
|
||||||
|
const maxPageFileSize = 1 << 20 // 1MB
|
||||||
|
|
||||||
|
type PageHandler struct {
|
||||||
|
pagesDir string
|
||||||
|
settingService *service.SettingService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPageHandler(dataDir string, settingService *service.SettingService) *PageHandler {
|
||||||
|
pagesDir := filepath.Join(dataDir, "pages")
|
||||||
|
_ = os.MkdirAll(pagesDir, 0755)
|
||||||
|
return &PageHandler{pagesDir: pagesDir, settingService: settingService}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPageContent serves raw markdown content for a given slug.
|
||||||
|
// GET /api/v1/pages/:slug
|
||||||
|
func (h *PageHandler) GetPageContent(c *gin.Context) {
|
||||||
|
slug := c.Param("slug")
|
||||||
|
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
|
||||||
|
response.BadRequest(c, "Invalid page slug")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visibility check: slug must be configured in custom_menu_items
|
||||||
|
// and the user must have permission based on visibility setting
|
||||||
|
if !h.checkSlugVisibility(c, slug) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(h.pagesDir, slug+".md")
|
||||||
|
cleaned := filepath.Clean(filePath)
|
||||||
|
if !strings.HasPrefix(cleaned, filepath.Clean(h.pagesDir)) {
|
||||||
|
response.BadRequest(c, "Invalid page slug")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(cleaned)
|
||||||
|
if err != nil || info.IsDir() {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "page not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if info.Size() > maxPageFileSize {
|
||||||
|
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "page too large"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read page"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Data(http.StatusOK, "text/markdown; charset=utf-8", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPages returns available page slugs.
|
||||||
|
// GET /api/v1/pages
|
||||||
|
func (h *PageHandler) ListPages(c *gin.Context) {
|
||||||
|
entries, err := os.ReadDir(h.pagesDir)
|
||||||
|
if err != nil {
|
||||||
|
response.Success(c, []string{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slugs := make([]string, 0, len(entries))
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if strings.HasSuffix(name, ".md") {
|
||||||
|
slugs = append(slugs, strings.TrimSuffix(name, ".md"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response.Success(c, slugs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServePageImage serves images from data/pages/{slug}/ directory.
|
||||||
|
// GET /api/v1/pages/:slug/images/*filename
|
||||||
|
// No JWT required (browser img tags can't carry tokens), but visibility is checked.
|
||||||
|
func (h *PageHandler) ServePageImage(c *gin.Context) {
|
||||||
|
slug := c.Param("slug")
|
||||||
|
filename := c.Param("filename")
|
||||||
|
filename = strings.TrimPrefix(filename, "/")
|
||||||
|
|
||||||
|
if !validSlugPattern.MatchString(slug) || len(slug) > 64 {
|
||||||
|
c.Status(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.checkImageSlugVisibility(c, slug) {
|
||||||
|
c.Status(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
imagesDir := filepath.Join(h.pagesDir, slug)
|
||||||
|
cleaned, ok := resolvePageImagePath(h.pagesDir, imagesDir, filename)
|
||||||
|
if !ok {
|
||||||
|
c.Status(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(cleaned)
|
||||||
|
if err != nil || info.IsDir() {
|
||||||
|
c.Status(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.File(cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePageImagePath(pagesDir, imagesDir, filename string) (string, bool) {
|
||||||
|
relPath, ok := cleanPageImageRelativePath(filename)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanedPagesDir := filepath.Clean(pagesDir)
|
||||||
|
cleanedImagesDir := filepath.Clean(imagesDir)
|
||||||
|
cleanedTarget := filepath.Clean(filepath.Join(cleanedImagesDir, relPath))
|
||||||
|
if !isPathWithinBase(cleanedTarget, cleanedImagesDir) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
realPagesDir, err := filepath.EvalSymlinks(cleanedPagesDir)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
realImagesDir, err := filepath.EvalSymlinks(cleanedImagesDir)
|
||||||
|
if err != nil || !isPathWithinBase(realImagesDir, realPagesDir) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
realTarget, err := filepath.EvalSymlinks(cleanedTarget)
|
||||||
|
if err != nil || !isPathWithinBase(realTarget, realImagesDir) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return realTarget, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanPageImageRelativePath(filename string) (string, bool) {
|
||||||
|
if filename == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(filename, "/") {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
decoded, err := url.PathUnescape(filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if decoded == "" || strings.HasPrefix(decoded, "/") || strings.Contains(decoded, "\\") || strings.ContainsRune(decoded, 0) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := make([]string, 0)
|
||||||
|
for _, part := range strings.Split(decoded, "/") {
|
||||||
|
switch part {
|
||||||
|
case "", ".":
|
||||||
|
continue
|
||||||
|
case "..":
|
||||||
|
return "", false
|
||||||
|
default:
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
relPath := filepath.Join(parts...)
|
||||||
|
if filepath.IsAbs(relPath) || filepath.VolumeName(relPath) != "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return relPath, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPathWithinBase(path, base string) bool {
|
||||||
|
rel, err := filepath.Rel(filepath.Clean(base), filepath.Clean(path))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rel != "." && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
|
||||||
|
}
|
||||||
|
|
||||||
|
// findSlugVisibility looks up the slug in custom_menu_items and returns (visibility, found).
|
||||||
|
func (h *PageHandler) findSlugVisibility(c *gin.Context, slug string) (string, bool) {
|
||||||
|
if h.settingService == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := h.settingService.GetCustomMenuItemsRaw(c.Request.Context())
|
||||||
|
if raw == "" || raw == "[]" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
PageSlug string `json:"page_slug"`
|
||||||
|
Visibility string `json:"visibility"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range items {
|
||||||
|
itemSlug := item.PageSlug
|
||||||
|
if itemSlug == "" && strings.HasPrefix(item.URL, "md:") {
|
||||||
|
itemSlug = strings.TrimPrefix(item.URL, "md:")
|
||||||
|
}
|
||||||
|
if itemSlug == slug {
|
||||||
|
return item.Visibility, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSlugVisibility verifies the slug is configured in custom_menu_items
|
||||||
|
// and the authenticated user has permission to view it.
|
||||||
|
func (h *PageHandler) checkSlugVisibility(c *gin.Context, slug string) bool {
|
||||||
|
visibility, found := h.findSlugVisibility(c, slug)
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if visibility == "admin" {
|
||||||
|
role, _ := middleware2.GetUserRoleFromContext(c)
|
||||||
|
return role == "admin"
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkImageSlugVisibility checks visibility for image requests (no JWT available).
|
||||||
|
// Only allows user-visible pages; admin-only pages are blocked.
|
||||||
|
func (h *PageHandler) checkImageSlugVisibility(c *gin.Context, slug string) bool {
|
||||||
|
visibility, found := h.findSlugVisibility(c, slug)
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return visibility != "admin"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterPageRoutes registers page routes on a router group.
|
||||||
|
func RegisterPageRoutes(v1 *gin.RouterGroup, dataDir string, jwtAuth gin.HandlerFunc, adminAuth gin.HandlerFunc, settingService *service.SettingService) {
|
||||||
|
h := NewPageHandler(dataDir, settingService)
|
||||||
|
|
||||||
|
// Authenticated page content (JWT required + visibility check)
|
||||||
|
pages := v1.Group("/pages")
|
||||||
|
pages.Use(jwtAuth)
|
||||||
|
{
|
||||||
|
pages.GET("/:slug", h.GetPageContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Images: no JWT (browser img tags can't carry tokens), visibility check in handler
|
||||||
|
pageImages := v1.Group("/pages")
|
||||||
|
{
|
||||||
|
pageImages.GET("/:slug/images/*filename", h.ServePageImage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin-only: list all available pages
|
||||||
|
adminPages := v1.Group("/pages")
|
||||||
|
adminPages.Use(adminAuth)
|
||||||
|
{
|
||||||
|
adminPages.GET("", h.ListPages)
|
||||||
|
}
|
||||||
|
}
|
||||||
102
backend/internal/handler/page_handler_test.go
Normal file
102
backend/internal/handler/page_handler_test.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCleanPageImageRelativePath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{name: "single filename", in: "logo.png", want: "logo.png", ok: true},
|
||||||
|
{name: "nested path", in: "images/logo.png", want: filepath.Join("images", "logo.png"), ok: true},
|
||||||
|
{name: "dot prefix", in: "./logo.png", want: "logo.png", ok: true},
|
||||||
|
{name: "url escaped slash", in: "images%2Flogo.png", want: filepath.Join("images", "logo.png"), ok: true},
|
||||||
|
{name: "parent traversal", in: "../secret.png", ok: false},
|
||||||
|
{name: "encoded parent traversal", in: "%2e%2e/secret.png", ok: false},
|
||||||
|
{name: "backslash traversal", in: `images\secret.png`, ok: false},
|
||||||
|
{name: "absolute path", in: "/etc/passwd", ok: false},
|
||||||
|
{name: "encoded absolute path", in: "%2fetc/passwd", ok: false},
|
||||||
|
{name: "encoded nul byte", in: "logo.png%00", ok: false},
|
||||||
|
{name: "invalid escape", in: "logo.png%zz", ok: false},
|
||||||
|
{name: "empty path", in: "", ok: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, ok := cleanPageImageRelativePath(tt.in)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Fatalf("ok = %v, want %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("path = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePageImagePath(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
pagesDir := filepath.Join(root, "pages")
|
||||||
|
base := filepath.Join(pagesDir, "guide")
|
||||||
|
if err := os.MkdirAll(filepath.Join(base, "images"), 0755); err != nil {
|
||||||
|
t.Fatalf("create images dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(base, "logo.png"), []byte("fake"), 0644); err != nil {
|
||||||
|
t.Fatalf("create direct image: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(base, "images", "logo.png"), []byte("fake"), 0644); err != nil {
|
||||||
|
t.Fatalf("create image: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := resolvePageImagePath(pagesDir, base, "logo.png")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected direct image path to be accepted")
|
||||||
|
}
|
||||||
|
want := filepath.Join(base, "logo.png")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("path = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok = resolvePageImagePath(pagesDir, base, "images/logo.png")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected nested image path to be accepted")
|
||||||
|
}
|
||||||
|
want = filepath.Join(base, "images", "logo.png")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("path = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := resolvePageImagePath(pagesDir, base, "../guide.md"); ok {
|
||||||
|
t.Fatalf("expected traversal to be rejected, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
pagesDir := filepath.Join(root, "pages")
|
||||||
|
base := filepath.Join(pagesDir, "guide")
|
||||||
|
outside := filepath.Join(root, "outside")
|
||||||
|
|
||||||
|
if err := os.MkdirAll(base, 0755); err != nil {
|
||||||
|
t.Fatalf("create page dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(outside, 0755); err != nil {
|
||||||
|
t.Fatalf("create outside dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(outside, "secret.png"), []byte("secret"), 0644); err != nil {
|
||||||
|
t.Fatalf("create outside file: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Symlink(outside, filepath.Join(base, "images")); err != nil {
|
||||||
|
t.Skipf("symlink not supported: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := resolvePageImagePath(pagesDir, base, "images/secret.png"); ok {
|
||||||
|
t.Fatalf("expected symlink escape to be rejected, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -40,6 +40,11 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
|
LoginAgreementEnabled: settings.LoginAgreementEnabled,
|
||||||
|
LoginAgreementMode: settings.LoginAgreementMode,
|
||||||
|
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
|
||||||
|
LoginAgreementRevision: settings.LoginAgreementRevision,
|
||||||
|
LoginAgreementDocuments: publicLoginAgreementDocumentsToDTO(settings.LoginAgreementDocuments),
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
SiteName: settings.SiteName,
|
SiteName: settings.SiteName,
|
||||||
@ -63,6 +68,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||||
|
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||||
|
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||||
BackendModeEnabled: settings.BackendModeEnabled,
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
PaymentEnabled: settings.PaymentEnabled,
|
PaymentEnabled: settings.PaymentEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
@ -77,5 +84,19 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||||
|
|
||||||
AffiliateEnabled: settings.AffiliateEnabled,
|
AffiliateEnabled: settings.AffiliateEnabled,
|
||||||
|
|
||||||
|
RiskControlEnabled: settings.RiskControlEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
||||||
|
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
result = append(result, dto.LoginAgreementDocument{
|
||||||
|
ID: item.ID,
|
||||||
|
Title: item.Title,
|
||||||
|
ContentMD: item.ContentMD,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@ -87,6 +87,8 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
|
|||||||
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||||
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||||
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||||
|
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||||
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@ -36,6 +36,7 @@ func ProvideAdminHandlers(
|
|||||||
channelHandler *admin.ChannelHandler,
|
channelHandler *admin.ChannelHandler,
|
||||||
channelMonitorHandler *admin.ChannelMonitorHandler,
|
channelMonitorHandler *admin.ChannelMonitorHandler,
|
||||||
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
|
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
|
||||||
|
contentModerationHandler *admin.ContentModerationHandler,
|
||||||
paymentHandler *admin.PaymentHandler,
|
paymentHandler *admin.PaymentHandler,
|
||||||
windsurfHandler *admin.WindsurfHandler,
|
windsurfHandler *admin.WindsurfHandler,
|
||||||
affiliateHandler *admin.AffiliateHandler,
|
affiliateHandler *admin.AffiliateHandler,
|
||||||
@ -68,6 +69,7 @@ func ProvideAdminHandlers(
|
|||||||
Channel: channelHandler,
|
Channel: channelHandler,
|
||||||
ChannelMonitor: channelMonitorHandler,
|
ChannelMonitor: channelMonitorHandler,
|
||||||
ChannelMonitorTemplate: channelMonitorTemplateHandler,
|
ChannelMonitorTemplate: channelMonitorTemplateHandler,
|
||||||
|
ContentModeration: contentModerationHandler,
|
||||||
Payment: paymentHandler,
|
Payment: paymentHandler,
|
||||||
Windsurf: windsurfHandler,
|
Windsurf: windsurfHandler,
|
||||||
Affiliate: affiliateHandler,
|
Affiliate: affiliateHandler,
|
||||||
@ -180,6 +182,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewChannelHandler,
|
admin.NewChannelHandler,
|
||||||
admin.NewChannelMonitorHandler,
|
admin.NewChannelMonitorHandler,
|
||||||
admin.NewChannelMonitorRequestTemplateHandler,
|
admin.NewChannelMonitorRequestTemplateHandler,
|
||||||
|
admin.NewContentModerationHandler,
|
||||||
admin.NewPaymentHandler,
|
admin.NewPaymentHandler,
|
||||||
admin.NewAffiliateHandler,
|
admin.NewAffiliateHandler,
|
||||||
|
|
||||||
|
|||||||
@ -317,6 +317,7 @@ const CLICurrentVersion = "2.1.92"
|
|||||||
// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
|
// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
|
||||||
// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
|
// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
|
||||||
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
|
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
|
||||||
|
// - 不默认加入 redact-thinking,避免上游抹除 thinking 内容;客户端显式传入时由合并逻辑保留。
|
||||||
func FullClaudeCodeMimicryBetas() []string {
|
func FullClaudeCodeMimicryBetas() []string {
|
||||||
return []string{
|
return []string{
|
||||||
BetaClaudeCode,
|
BetaClaudeCode,
|
||||||
@ -324,7 +325,6 @@ func FullClaudeCodeMimicryBetas() []string {
|
|||||||
BetaInterleavedThinking,
|
BetaInterleavedThinking,
|
||||||
BetaPromptCachingScope,
|
BetaPromptCachingScope,
|
||||||
BetaEffort,
|
BetaEffort,
|
||||||
BetaRedactThinking,
|
|
||||||
BetaContextManagement,
|
BetaContextManagement,
|
||||||
BetaExtendedCacheTTL,
|
BetaExtendedCacheTTL,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
apikey.FieldID,
|
apikey.FieldID,
|
||||||
apikey.FieldUserID,
|
apikey.FieldUserID,
|
||||||
apikey.FieldGroupID,
|
apikey.FieldGroupID,
|
||||||
|
apikey.FieldName,
|
||||||
apikey.FieldStatus,
|
apikey.FieldStatus,
|
||||||
apikey.FieldIPWhitelist,
|
apikey.FieldIPWhitelist,
|
||||||
apikey.FieldIPBlacklist,
|
apikey.FieldIPBlacklist,
|
||||||
|
|||||||
@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S
|
|||||||
|
|
||||||
got, err := repo.GetByKeyForAuth(ctx, key.Key)
|
got, err := repo.GetByKeyForAuth(ctx, key.Key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, key.Name, got.Name)
|
||||||
require.NotNil(t, got.Group)
|
require.NotNil(t, got.Group)
|
||||||
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
|
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
71
backend/internal/repository/content_moderation_hash_cache.go
Normal file
71
backend/internal/repository/content_moderation_hash_cache.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes"
|
||||||
|
|
||||||
|
type contentModerationHashCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache {
|
||||||
|
return &contentModerationHashCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error {
|
||||||
|
inputHash = strings.TrimSpace(inputHash)
|
||||||
|
if c == nil || c.rdb == nil || inputHash == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||||||
|
inputHash = strings.TrimSpace(inputHash)
|
||||||
|
if c == nil || c.rdb == nil || inputHash == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) {
|
||||||
|
inputHash = strings.TrimSpace(inputHash)
|
||||||
|
if c == nil || c.rdb == nil || inputHash == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return deleted > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||||||
|
if c == nil || c.rdb == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if deleted == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) {
|
||||||
|
if c == nil || c.rdb == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result()
|
||||||
|
}
|
||||||
274
backend/internal/repository/content_moderation_repo.go
Normal file
274
backend/internal/repository/content_moderation_repo.go
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contentModerationRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository {
|
||||||
|
return &contentModerationRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||||
|
if log == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
categoryScores, err := json.Marshal(log.CategoryScores)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal moderation category scores: %w", err)
|
||||||
|
}
|
||||||
|
thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal moderation thresholds: %w", err)
|
||||||
|
}
|
||||||
|
var userID any
|
||||||
|
if log.UserID != nil {
|
||||||
|
userID = *log.UserID
|
||||||
|
}
|
||||||
|
var apiKeyID any
|
||||||
|
if log.APIKeyID != nil {
|
||||||
|
apiKeyID = *log.APIKeyID
|
||||||
|
}
|
||||||
|
var groupID any
|
||||||
|
if log.GroupID != nil {
|
||||||
|
groupID = *log.GroupID
|
||||||
|
}
|
||||||
|
var latency any
|
||||||
|
if log.UpstreamLatencyMS != nil {
|
||||||
|
latency = *log.UpstreamLatencyMS
|
||||||
|
}
|
||||||
|
err = r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO content_moderation_logs (
|
||||||
|
request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name,
|
||||||
|
endpoint, provider, model, mode, action, flagged, highest_category, highest_score,
|
||||||
|
category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error,
|
||||||
|
violation_count, auto_banned, email_sent, queue_delay_ms
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7,
|
||||||
|
$8, $9, $10, $11, $12, $13, $14, $15,
|
||||||
|
$16::jsonb, $17::jsonb, $18, $19, $20,
|
||||||
|
$21, $22, $23, $24
|
||||||
|
) RETURNING id, created_at`,
|
||||||
|
log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName,
|
||||||
|
log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore,
|
||||||
|
string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error,
|
||||||
|
log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS),
|
||||||
|
).Scan(&log.ID, &log.CreatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert content moderation log: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||||
|
where, args := buildContentModerationLogWhere(filter)
|
||||||
|
whereSQL := "WHERE " + strings.Join(where, " AND ")
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("count content moderation logs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := filter.Pagination
|
||||||
|
if params.Page <= 0 {
|
||||||
|
params.Page = 1
|
||||||
|
}
|
||||||
|
if params.PageSize <= 0 {
|
||||||
|
params.PageSize = 20
|
||||||
|
}
|
||||||
|
if params.PageSize > 100 {
|
||||||
|
params.PageSize = 100
|
||||||
|
}
|
||||||
|
queryArgs := append([]any{}, args...)
|
||||||
|
queryArgs = append(queryArgs, params.Limit(), params.Offset())
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT
|
||||||
|
l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name,
|
||||||
|
l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score,
|
||||||
|
l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error,
|
||||||
|
l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at
|
||||||
|
FROM content_moderation_logs l
|
||||||
|
LEFT JOIN users u ON u.id = l.user_id `+whereSQL+`
|
||||||
|
ORDER BY l.created_at DESC, l.id DESC
|
||||||
|
LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)),
|
||||||
|
queryArgs...,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("list content moderation logs: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
items := make([]service.ContentModerationLog, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var item service.ContentModerationLog
|
||||||
|
var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64
|
||||||
|
var scoresRaw, thresholdsRaw []byte
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.ID,
|
||||||
|
&item.RequestID,
|
||||||
|
&userID,
|
||||||
|
&item.UserEmail,
|
||||||
|
&apiKeyID,
|
||||||
|
&item.APIKeyName,
|
||||||
|
&groupID,
|
||||||
|
&item.GroupName,
|
||||||
|
&item.Endpoint,
|
||||||
|
&item.Provider,
|
||||||
|
&item.Model,
|
||||||
|
&item.Mode,
|
||||||
|
&item.Action,
|
||||||
|
&item.Flagged,
|
||||||
|
&item.HighestCategory,
|
||||||
|
&item.HighestScore,
|
||||||
|
&scoresRaw,
|
||||||
|
&thresholdsRaw,
|
||||||
|
&item.InputExcerpt,
|
||||||
|
&latency,
|
||||||
|
&item.Error,
|
||||||
|
&item.ViolationCount,
|
||||||
|
&item.AutoBanned,
|
||||||
|
&item.EmailSent,
|
||||||
|
&item.UserStatus,
|
||||||
|
&queueDelay,
|
||||||
|
&item.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("scan content moderation log: %w", err)
|
||||||
|
}
|
||||||
|
if userID.Valid {
|
||||||
|
v := userID.Int64
|
||||||
|
item.UserID = &v
|
||||||
|
}
|
||||||
|
if apiKeyID.Valid {
|
||||||
|
v := apiKeyID.Int64
|
||||||
|
item.APIKeyID = &v
|
||||||
|
}
|
||||||
|
if groupID.Valid {
|
||||||
|
v := groupID.Int64
|
||||||
|
item.GroupID = &v
|
||||||
|
}
|
||||||
|
if latency.Valid {
|
||||||
|
v := int(latency.Int64)
|
||||||
|
item.UpstreamLatencyMS = &v
|
||||||
|
}
|
||||||
|
if queueDelay.Valid {
|
||||||
|
v := int(queueDelay.Int64)
|
||||||
|
item.QueueDelayMS = &v
|
||||||
|
}
|
||||||
|
item.CategoryScores = map[string]float64{}
|
||||||
|
_ = json.Unmarshal(scoresRaw, &item.CategoryScores)
|
||||||
|
item.ThresholdSnapshot = map[string]float64{}
|
||||||
|
_ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot)
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err)
|
||||||
|
}
|
||||||
|
return items, paginationResultFromTotal(total, params), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
var count int
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
WITH last_auto_ban AS (
|
||||||
|
SELECT MAX(created_at) AS at
|
||||||
|
FROM content_moderation_logs
|
||||||
|
WHERE user_id = $1 AND auto_banned = TRUE
|
||||||
|
)
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM content_moderation_logs
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND flagged = TRUE
|
||||||
|
AND created_at >= $2
|
||||||
|
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
|
||||||
|
`, userID, since).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("count user content moderation flagged logs: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||||
|
result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()}
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
hitExec, err := r.db.ExecContext(ctx, `
|
||||||
|
DELETE FROM content_moderation_logs
|
||||||
|
WHERE flagged = TRUE AND created_at < $1
|
||||||
|
`, hitBefore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err)
|
||||||
|
}
|
||||||
|
result.DeletedHit, _ = hitExec.RowsAffected()
|
||||||
|
|
||||||
|
nonHitExec, err := r.db.ExecContext(ctx, `
|
||||||
|
DELETE FROM content_moderation_logs
|
||||||
|
WHERE flagged = FALSE AND created_at < $1
|
||||||
|
`, nonHitBefore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err)
|
||||||
|
}
|
||||||
|
result.DeletedNonHit, _ = nonHitExec.RowsAffected()
|
||||||
|
|
||||||
|
result.FinishedAt = time.Now()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullableIntPtr(value *int) any {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) {
|
||||||
|
where := []string{"l.id IS NOT NULL"}
|
||||||
|
args := make([]any, 0)
|
||||||
|
add := func(expr string, value any) {
|
||||||
|
args = append(args, value)
|
||||||
|
where = append(where, fmt.Sprintf(expr, len(args)))
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(filter.Result)) {
|
||||||
|
case "hit", "flagged":
|
||||||
|
where = append(where, "l.flagged = TRUE")
|
||||||
|
case "blocked", "block":
|
||||||
|
where = append(where, "l.action = 'block'")
|
||||||
|
case "pass", "allow":
|
||||||
|
where = append(where, "l.flagged = FALSE AND l.error = ''")
|
||||||
|
case "error":
|
||||||
|
where = append(where, "l.error <> ''")
|
||||||
|
}
|
||||||
|
if filter.GroupID != nil {
|
||||||
|
add("l.group_id = $%d", *filter.GroupID)
|
||||||
|
}
|
||||||
|
if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" {
|
||||||
|
add("l.endpoint = $%d", endpoint)
|
||||||
|
}
|
||||||
|
if search := strings.TrimSpace(filter.Search); search != "" {
|
||||||
|
like := "%" + search + "%"
|
||||||
|
args = append(args, like, like, like, like, like)
|
||||||
|
idx := len(args) - 4
|
||||||
|
where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4))
|
||||||
|
}
|
||||||
|
if filter.From != nil && !filter.From.IsZero() {
|
||||||
|
add("l.created_at >= $%d", *filter.From)
|
||||||
|
}
|
||||||
|
if filter.To != nil && !filter.To.IsZero() {
|
||||||
|
add("l.created_at <= $%d", *filter.To)
|
||||||
|
}
|
||||||
|
return where, args
|
||||||
|
}
|
||||||
@ -737,6 +737,37 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if value < 0 {
|
||||||
|
value = 0
|
||||||
|
}
|
||||||
|
res, err := r.sql.ExecContext(ctx,
|
||||||
|
"UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
|
||||||
|
value, pq.Array(userIDs))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("batch set concurrency: %w", err)
|
||||||
|
}
|
||||||
|
affected, _ := res.RowsAffected()
|
||||||
|
return int(affected), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
res, err := r.sql.ExecContext(ctx,
|
||||||
|
"UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
|
||||||
|
delta, pq.Array(userIDs))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("batch add concurrency: %w", err)
|
||||||
|
}
|
||||||
|
affected, _ := res.RowsAffected()
|
||||||
|
return int(affected), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewChannelRepository,
|
NewChannelRepository,
|
||||||
NewChannelMonitorRepository,
|
NewChannelMonitorRepository,
|
||||||
NewChannelMonitorRequestTemplateRepository,
|
NewChannelMonitorRequestTemplateRepository,
|
||||||
|
NewContentModerationRepository,
|
||||||
NewAffiliateRepository,
|
NewAffiliateRepository,
|
||||||
|
|
||||||
// Cache implementations
|
// Cache implementations
|
||||||
@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewRefreshTokenCache,
|
NewRefreshTokenCache,
|
||||||
NewErrorPassthroughCache,
|
NewErrorPassthroughCache,
|
||||||
NewTLSFingerprintProfileCache,
|
NewTLSFingerprintProfileCache,
|
||||||
|
NewContentModerationHashCache,
|
||||||
|
|
||||||
// Encryptors
|
// Encryptors
|
||||||
NewAESEncryptor,
|
NewAESEncryptor,
|
||||||
|
|||||||
@ -646,12 +646,21 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"registration_email_suffix_whitelist": [],
|
"registration_email_suffix_whitelist": [],
|
||||||
"promo_code_enabled": true,
|
"promo_code_enabled": true,
|
||||||
"password_reset_enabled": false,
|
"password_reset_enabled": false,
|
||||||
"frontend_url": "",
|
"frontend_url": "",
|
||||||
"totp_enabled": false,
|
"totp_enabled": false,
|
||||||
"totp_encryption_key_configured": false,
|
"totp_encryption_key_configured": false,
|
||||||
"smtp_host": "smtp.example.com",
|
"login_agreement_enabled": false,
|
||||||
"smtp_port": 587,
|
"login_agreement_mode": "modal",
|
||||||
"smtp_username": "user",
|
"login_agreement_updated_at": "2026-03-31",
|
||||||
|
"login_agreement_documents": [
|
||||||
|
{"id": "terms", "title": "服务条款", "content_md": ""},
|
||||||
|
{"id": "usage-policy", "title": "使用政策", "content_md": ""},
|
||||||
|
{"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""},
|
||||||
|
{"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""}
|
||||||
|
],
|
||||||
|
"smtp_host": "smtp.example.com",
|
||||||
|
"smtp_port": 587,
|
||||||
|
"smtp_username": "user",
|
||||||
"smtp_password_configured": true,
|
"smtp_password_configured": true,
|
||||||
"smtp_from_email": "no-reply@example.com",
|
"smtp_from_email": "no-reply@example.com",
|
||||||
"smtp_from_name": "Sub2API",
|
"smtp_from_name": "Sub2API",
|
||||||
@ -685,6 +694,16 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"oidc_connect_userinfo_email_path": "",
|
"oidc_connect_userinfo_email_path": "",
|
||||||
"oidc_connect_userinfo_id_path": "",
|
"oidc_connect_userinfo_id_path": "",
|
||||||
"oidc_connect_userinfo_username_path": "",
|
"oidc_connect_userinfo_username_path": "",
|
||||||
|
"github_oauth_enabled": false,
|
||||||
|
"github_oauth_client_id": "",
|
||||||
|
"github_oauth_client_secret_configured": false,
|
||||||
|
"github_oauth_redirect_url": "",
|
||||||
|
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||||
|
"google_oauth_enabled": false,
|
||||||
|
"google_oauth_client_id": "",
|
||||||
|
"google_oauth_client_secret_configured": false,
|
||||||
|
"google_oauth_redirect_url": "",
|
||||||
|
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||||
"ops_monitoring_enabled": false,
|
"ops_monitoring_enabled": false,
|
||||||
"ops_realtime_monitoring_enabled": true,
|
"ops_realtime_monitoring_enabled": true,
|
||||||
"ops_query_mode_default": "auto",
|
"ops_query_mode_default": "auto",
|
||||||
@ -700,6 +719,16 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"auth_source_default_email_subscriptions": [],
|
"auth_source_default_email_subscriptions": [],
|
||||||
"auth_source_default_email_grant_on_signup": false,
|
"auth_source_default_email_grant_on_signup": false,
|
||||||
"auth_source_default_email_grant_on_first_bind": false,
|
"auth_source_default_email_grant_on_first_bind": false,
|
||||||
|
"auth_source_default_github_balance": 0,
|
||||||
|
"auth_source_default_github_concurrency": 5,
|
||||||
|
"auth_source_default_github_subscriptions": [],
|
||||||
|
"auth_source_default_github_grant_on_signup": false,
|
||||||
|
"auth_source_default_github_grant_on_first_bind": false,
|
||||||
|
"auth_source_default_google_balance": 0,
|
||||||
|
"auth_source_default_google_concurrency": 5,
|
||||||
|
"auth_source_default_google_subscriptions": [],
|
||||||
|
"auth_source_default_google_grant_on_signup": false,
|
||||||
|
"auth_source_default_google_grant_on_first_bind": false,
|
||||||
"auth_source_default_linuxdo_balance": 0,
|
"auth_source_default_linuxdo_balance": 0,
|
||||||
"auth_source_default_linuxdo_concurrency": 5,
|
"auth_source_default_linuxdo_concurrency": 5,
|
||||||
"auth_source_default_linuxdo_subscriptions": [],
|
"auth_source_default_linuxdo_subscriptions": [],
|
||||||
@ -792,6 +821,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"channel_monitor_enabled": true,
|
"channel_monitor_enabled": true,
|
||||||
"channel_monitor_default_interval_seconds": 60,
|
"channel_monitor_default_interval_seconds": 60,
|
||||||
"available_channels_enabled": false,
|
"available_channels_enabled": false,
|
||||||
|
"risk_control_enabled": false,
|
||||||
"affiliate_enabled": false,
|
"affiliate_enabled": false,
|
||||||
"wechat_connect_enabled": false,
|
"wechat_connect_enabled": false,
|
||||||
"wechat_connect_app_id": "",
|
"wechat_connect_app_id": "",
|
||||||
@ -859,12 +889,21 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"promo_code_enabled": true,
|
"promo_code_enabled": true,
|
||||||
"password_reset_enabled": false,
|
"password_reset_enabled": false,
|
||||||
"frontend_url": "",
|
"frontend_url": "",
|
||||||
"invitation_code_enabled": false,
|
"invitation_code_enabled": false,
|
||||||
"totp_enabled": false,
|
"totp_enabled": false,
|
||||||
"totp_encryption_key_configured": false,
|
"totp_encryption_key_configured": false,
|
||||||
"smtp_host": "",
|
"login_agreement_enabled": false,
|
||||||
"smtp_port": 587,
|
"login_agreement_mode": "modal",
|
||||||
"smtp_username": "",
|
"login_agreement_updated_at": "2026-03-31",
|
||||||
|
"login_agreement_documents": [
|
||||||
|
{"id": "terms", "title": "服务条款", "content_md": ""},
|
||||||
|
{"id": "usage-policy", "title": "使用政策", "content_md": ""},
|
||||||
|
{"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""},
|
||||||
|
{"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""}
|
||||||
|
],
|
||||||
|
"smtp_host": "",
|
||||||
|
"smtp_port": 587,
|
||||||
|
"smtp_username": "",
|
||||||
"smtp_password_configured": false,
|
"smtp_password_configured": false,
|
||||||
"smtp_from_email": "",
|
"smtp_from_email": "",
|
||||||
"smtp_from_name": "",
|
"smtp_from_name": "",
|
||||||
@ -898,6 +937,16 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"oidc_connect_userinfo_email_path": "",
|
"oidc_connect_userinfo_email_path": "",
|
||||||
"oidc_connect_userinfo_id_path": "",
|
"oidc_connect_userinfo_id_path": "",
|
||||||
"oidc_connect_userinfo_username_path": "",
|
"oidc_connect_userinfo_username_path": "",
|
||||||
|
"github_oauth_enabled": false,
|
||||||
|
"github_oauth_client_id": "",
|
||||||
|
"github_oauth_client_secret_configured": false,
|
||||||
|
"github_oauth_redirect_url": "",
|
||||||
|
"github_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||||
|
"google_oauth_enabled": false,
|
||||||
|
"google_oauth_client_id": "",
|
||||||
|
"google_oauth_client_secret_configured": false,
|
||||||
|
"google_oauth_redirect_url": "",
|
||||||
|
"google_oauth_frontend_redirect_url": "/auth/oauth/callback",
|
||||||
"site_name": "Sub2API",
|
"site_name": "Sub2API",
|
||||||
"site_logo": "",
|
"site_logo": "",
|
||||||
"site_subtitle": "Subscription to API Conversion Platform",
|
"site_subtitle": "Subscription to API Conversion Platform",
|
||||||
@ -983,6 +1032,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"channel_monitor_enabled": true,
|
"channel_monitor_enabled": true,
|
||||||
"channel_monitor_default_interval_seconds": 60,
|
"channel_monitor_default_interval_seconds": 60,
|
||||||
"available_channels_enabled": false,
|
"available_channels_enabled": false,
|
||||||
|
"risk_control_enabled": false,
|
||||||
"affiliate_enabled": false,
|
"affiliate_enabled": false,
|
||||||
"wechat_connect_enabled": true,
|
"wechat_connect_enabled": true,
|
||||||
"wechat_connect_app_id": "wx-open-config",
|
"wechat_connect_app_id": "wx-open-config",
|
||||||
@ -1005,6 +1055,16 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"auth_source_default_email_subscriptions": [],
|
"auth_source_default_email_subscriptions": [],
|
||||||
"auth_source_default_email_grant_on_signup": false,
|
"auth_source_default_email_grant_on_signup": false,
|
||||||
"auth_source_default_email_grant_on_first_bind": false,
|
"auth_source_default_email_grant_on_first_bind": false,
|
||||||
|
"auth_source_default_github_balance": 0,
|
||||||
|
"auth_source_default_github_concurrency": 5,
|
||||||
|
"auth_source_default_github_subscriptions": [],
|
||||||
|
"auth_source_default_github_grant_on_signup": false,
|
||||||
|
"auth_source_default_github_grant_on_first_bind": false,
|
||||||
|
"auth_source_default_google_balance": 0,
|
||||||
|
"auth_source_default_google_concurrency": 5,
|
||||||
|
"auth_source_default_google_subscriptions": [],
|
||||||
|
"auth_source_default_google_grant_on_signup": false,
|
||||||
|
"auth_source_default_google_grant_on_first_bind": false,
|
||||||
"auth_source_default_linuxdo_balance": 0,
|
"auth_source_default_linuxdo_balance": 0,
|
||||||
"auth_source_default_linuxdo_concurrency": 5,
|
"auth_source_default_linuxdo_concurrency": 5,
|
||||||
"auth_source_default_linuxdo_subscriptions": [],
|
"auth_source_default_linuxdo_subscriptions": [],
|
||||||
@ -1123,7 +1183,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
|
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
|
|
||||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil, nil)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
|
|
||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
@ -1294,6 +1354,9 @@ func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
|
|||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (r *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
return false, errors.New("not implemented")
|
return false, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -198,6 +198,9 @@ func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i
|
|||||||
panic("unexpected UpdateConcurrency call")
|
panic("unexpected UpdateConcurrency call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (s *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
panic("unexpected ExistsByEmail call")
|
panic("unexpected ExistsByEmail call")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,6 +40,8 @@ func backendModeAllowsAuthPath(path string) bool {
|
|||||||
"/auth/oauth/wechat/callback",
|
"/auth/oauth/wechat/callback",
|
||||||
"/auth/oauth/wechat/payment/callback",
|
"/auth/oauth/wechat/payment/callback",
|
||||||
"/auth/oauth/oidc/callback",
|
"/auth/oauth/oidc/callback",
|
||||||
|
"/auth/oauth/github/callback",
|
||||||
|
"/auth/oauth/google/callback",
|
||||||
"/auth/oauth/linuxdo/complete-registration",
|
"/auth/oauth/linuxdo/complete-registration",
|
||||||
"/auth/oauth/wechat/complete-registration",
|
"/auth/oauth/wechat/complete-registration",
|
||||||
"/auth/oauth/oidc/complete-registration",
|
"/auth/oauth/oidc/complete-registration",
|
||||||
|
|||||||
@ -246,6 +246,30 @@ func TestBackendModeAuthGuard(t *testing.T) {
|
|||||||
path: "/api/v1/auth/oauth/oidc/callback",
|
path: "/api/v1/auth/oauth/oidc/callback",
|
||||||
wantStatus: http.StatusOK,
|
wantStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_blocks_github_oauth_start",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/oauth/github/start",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_allows_github_oauth_callback",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/oauth/github/callback",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_blocks_google_oauth_start",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/oauth/google/start",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_allows_google_oauth_callback",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/oauth/google/callback",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "enabled_allows_oauth_pending_exchange",
|
name: "enabled_allows_oauth_pending_exchange",
|
||||||
enabled: "true",
|
enabled: "true",
|
||||||
|
|||||||
@ -120,4 +120,6 @@ func registerRoutes(
|
|||||||
routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster)
|
routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster)
|
||||||
|
|
||||||
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
|
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
|
||||||
|
|
||||||
|
handler.RegisterPageRoutes(v1, cfg.Pricing.DataDir, gin.HandlerFunc(jwtAuth), gin.HandlerFunc(adminAuth), settingService)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -95,11 +95,28 @@ func RegisterAdminRoutes(
|
|||||||
// 渠道监控
|
// 渠道监控
|
||||||
registerChannelMonitorRoutes(admin, h)
|
registerChannelMonitorRoutes(admin, h)
|
||||||
|
|
||||||
|
// 风控中心
|
||||||
|
registerContentModerationRoutes(admin, h)
|
||||||
|
|
||||||
// 邀请返利(专属用户管理)
|
// 邀请返利(专属用户管理)
|
||||||
registerAffiliateRoutes(admin, h)
|
registerAffiliateRoutes(admin, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
risk := admin.Group("/risk-control")
|
||||||
|
{
|
||||||
|
risk.GET("/config", h.Admin.ContentModeration.GetConfig)
|
||||||
|
risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig)
|
||||||
|
risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys)
|
||||||
|
risk.GET("/status", h.Admin.ContentModeration.GetStatus)
|
||||||
|
risk.GET("/logs", h.Admin.ContentModeration.ListLogs)
|
||||||
|
risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser)
|
||||||
|
risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash)
|
||||||
|
risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
apiKeys := admin.Group("/api-keys")
|
apiKeys := admin.Group("/api-keys")
|
||||||
{
|
{
|
||||||
@ -234,6 +251,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
|
||||||
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
||||||
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
||||||
|
users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency)
|
||||||
|
|
||||||
// User attribute values
|
// User attribute values
|
||||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||||
@ -270,6 +288,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||||
accounts.POST("", h.Admin.Account.Create)
|
accounts.POST("", h.Admin.Account.Create)
|
||||||
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
||||||
|
accounts.POST("/import/codex-session", h.Admin.Account.ImportCodexSession)
|
||||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||||
|
|||||||
@ -63,6 +63,22 @@ func RegisterAuthRoutes(
|
|||||||
FailureMode: middleware.RateLimitFailClose,
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
}), h.Auth.ResetPassword)
|
}), h.Auth.ResetPassword)
|
||||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||||
|
auth.GET("/oauth/github/start", h.Auth.GitHubOAuthStart)
|
||||||
|
auth.GET("/oauth/github/callback", h.Auth.GitHubOAuthCallback)
|
||||||
|
auth.POST("/oauth/github/complete-registration",
|
||||||
|
rateLimiter.LimitWithOptions("oauth-github-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}),
|
||||||
|
h.Auth.CompleteGitHubOAuthRegistration,
|
||||||
|
)
|
||||||
|
auth.GET("/oauth/google/start", h.Auth.GoogleOAuthStart)
|
||||||
|
auth.GET("/oauth/google/callback", h.Auth.GoogleOAuthCallback)
|
||||||
|
auth.POST("/oauth/google/complete-registration",
|
||||||
|
rateLimiter.LimitWithOptions("oauth-google-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}),
|
||||||
|
h.Auth.CompleteGoogleOAuthRegistration,
|
||||||
|
)
|
||||||
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
|
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
|
||||||
query := c.Request.URL.Query()
|
query := c.Request.URL.Query()
|
||||||
query.Set("intent", "bind_current_user")
|
query.Set("intent", "bind_current_user")
|
||||||
|
|||||||
@ -33,6 +33,7 @@ type AdminService interface {
|
|||||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||||
DeleteUser(ctx context.Context, id int64) error
|
DeleteUser(ctx context.Context, id int64) error
|
||||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||||
|
BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error)
|
||||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||||
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
||||||
@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) {
|
||||||
|
cleaned := make([]int64, 0, len(userIDs))
|
||||||
|
for _, uid := range userIDs {
|
||||||
|
if uid > 0 {
|
||||||
|
cleaned = append(cleaned, uid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(cleaned) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var affected int
|
||||||
|
var err error
|
||||||
|
switch mode {
|
||||||
|
case "set":
|
||||||
|
affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value)
|
||||||
|
case "add":
|
||||||
|
affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value)
|
||||||
|
default:
|
||||||
|
return 0, errors.New("invalid mode: must be 'set' or 'add'")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
for _, uid := range cleaned {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return affected, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
|
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
|
||||||
user, err := s.userRepo.GetByID(ctx, userID)
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float
|
|||||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i
|
|||||||
panic("unexpected UpdateConcurrency call")
|
panic("unexpected UpdateConcurrency call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
if s.existsErr != nil {
|
if s.existsErr != nil {
|
||||||
return false, s.existsErr
|
return false, s.existsErr
|
||||||
|
|||||||
@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64)
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||||
|
|
||||||
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||||
|
|||||||
@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct {
|
|||||||
APIKeyID int64 `json:"api_key_id"`
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
GroupID *int64 `json:"group_id,omitempty"`
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
|
Name string `json:"name"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
|
|||||||
@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/dgraph-io/ristretto"
|
"github.com/dgraph-io/ristretto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
|
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
|
||||||
|
|
||||||
type apiKeyAuthCacheConfig struct {
|
type apiKeyAuthCacheConfig struct {
|
||||||
l1Size int
|
l1Size int
|
||||||
@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
|
|||||||
APIKeyID: apiKey.ID,
|
APIKeyID: apiKey.ID,
|
||||||
UserID: apiKey.UserID,
|
UserID: apiKey.UserID,
|
||||||
GroupID: apiKey.GroupID,
|
GroupID: apiKey.GroupID,
|
||||||
|
Name: apiKey.Name,
|
||||||
Status: apiKey.Status,
|
Status: apiKey.Status,
|
||||||
IPWhitelist: apiKey.IPWhitelist,
|
IPWhitelist: apiKey.IPWhitelist,
|
||||||
IPBlacklist: apiKey.IPBlacklist,
|
IPBlacklist: apiKey.IPBlacklist,
|
||||||
@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
UserID: snapshot.UserID,
|
UserID: snapshot.UserID,
|
||||||
GroupID: snapshot.GroupID,
|
GroupID: snapshot.GroupID,
|
||||||
Key: key,
|
Key: key,
|
||||||
|
Name: snapshot.Name,
|
||||||
Status: snapshot.Status,
|
Status: snapshot.Status,
|
||||||
IPWhitelist: snapshot.IPWhitelist,
|
IPWhitelist: snapshot.IPWhitelist,
|
||||||
IPBlacklist: snapshot.IPBlacklist,
|
IPBlacklist: snapshot.IPBlacklist,
|
||||||
|
|||||||
@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
|||||||
UserID: 2,
|
UserID: 2,
|
||||||
GroupID: &groupID,
|
GroupID: &groupID,
|
||||||
Key: "k-roundtrip",
|
Key: "k-roundtrip",
|
||||||
|
Name: "Audit Key",
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
User: &User{
|
User: &User{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
|||||||
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
||||||
|
|
||||||
require.NotNil(t, roundTrip)
|
require.NotNil(t, roundTrip)
|
||||||
|
require.Equal(t, apiKey.Name, roundTrip.Name)
|
||||||
require.NotNil(t, roundTrip.Group)
|
require.NotNil(t, roundTrip.Group)
|
||||||
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
|
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
274
backend/internal/service/auth_email_oauth_auto.go
Normal file
274
backend/internal/service/auth_email_oauth_auto.go
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/mail"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmailOAuthIdentityInput struct {
|
||||||
|
ProviderType string
|
||||||
|
ProviderKey string
|
||||||
|
ProviderSubject string
|
||||||
|
Email string
|
||||||
|
EmailVerified bool
|
||||||
|
Username string
|
||||||
|
DisplayName string
|
||||||
|
AvatarURL string
|
||||||
|
UpstreamMetadata map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuth(ctx context.Context, input EmailOAuthIdentityInput) (*TokenPair, *User, error) {
|
||||||
|
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, "", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) LoginOrRegisterVerifiedEmailOAuthWithInvitation(
|
||||||
|
ctx context.Context,
|
||||||
|
input EmailOAuthIdentityInput,
|
||||||
|
invitationCode string,
|
||||||
|
affiliateCode string,
|
||||||
|
) (*TokenPair, *User, error) {
|
||||||
|
return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, invitationCode, affiliateCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
|
||||||
|
ctx context.Context,
|
||||||
|
input EmailOAuthIdentityInput,
|
||||||
|
invitationCode string,
|
||||||
|
affiliateCode string,
|
||||||
|
) (*TokenPair, *User, error) {
|
||||||
|
if s == nil || s.userRepo == nil || s.entClient == nil {
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||||
|
if providerType != "github" && providerType != "google" {
|
||||||
|
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||||
|
if providerKey == "" {
|
||||||
|
providerKey = providerType
|
||||||
|
}
|
||||||
|
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||||
|
if providerSubject == "" {
|
||||||
|
return nil, nil, infraerrors.BadRequest("OAUTH_SUBJECT_MISSING", "oauth subject is missing")
|
||||||
|
}
|
||||||
|
if !input.EmailVerified {
|
||||||
|
return nil, nil, infraerrors.Forbidden("OAUTH_EMAIL_NOT_VERIFIED", "oauth email is not verified")
|
||||||
|
}
|
||||||
|
|
||||||
|
email := strings.TrimSpace(strings.ToLower(input.Email))
|
||||||
|
if email == "" || len(email) > 255 {
|
||||||
|
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||||
|
}
|
||||||
|
if _, err := mail.ParseAddress(email); err != nil {
|
||||||
|
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||||
|
}
|
||||||
|
if isReservedEmail(email) {
|
||||||
|
return nil, nil, ErrEmailReserved
|
||||||
|
}
|
||||||
|
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
identityUser, err := s.findEmailOAuthIdentityOwner(ctx, providerType, providerKey, providerSubject)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if identityUser != nil && !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
|
||||||
|
return nil, nil, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
|
||||||
|
}
|
||||||
|
|
||||||
|
user := identityUser
|
||||||
|
created := false
|
||||||
|
if user == nil {
|
||||||
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
user, err = s.createEmailOAuthUser(ctx, email, input.Username, providerType, invitationCode, affiliateCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
created = true
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Database error during %s oauth login: %v", providerType, err)
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.IsActive() {
|
||||||
|
return nil, nil, ErrUserNotActive
|
||||||
|
}
|
||||||
|
if err := s.ensureEmailOAuthIdentity(ctx, user.ID, EmailOAuthIdentityInput{
|
||||||
|
ProviderType: providerType,
|
||||||
|
ProviderKey: providerKey,
|
||||||
|
ProviderSubject: providerSubject,
|
||||||
|
Email: email,
|
||||||
|
EmailVerified: input.EmailVerified,
|
||||||
|
Username: input.Username,
|
||||||
|
DisplayName: input.DisplayName,
|
||||||
|
AvatarURL: input.AvatarURL,
|
||||||
|
UpstreamMetadata: input.UpstreamMetadata,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Username == "" && strings.TrimSpace(input.Username) != "" {
|
||||||
|
user.Username = strings.TrimSpace(input.Username)
|
||||||
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after %s oauth login: %v", providerType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !created {
|
||||||
|
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, providerType); err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply %s first bind defaults: %v", providerType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.RecordSuccessfulLogin(ctx, user.ID)
|
||||||
|
|
||||||
|
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||||
|
}
|
||||||
|
return tokenPair, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, providerType, invitationCode, affiliateCode string) (*User, error) {
|
||||||
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
|
return nil, ErrRegDisabled
|
||||||
|
}
|
||||||
|
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrInvitationCodeRequired) {
|
||||||
|
return nil, ErrOAuthInvitationRequired
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
randomPassword, err := randomHexString(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
hashedPassword, err := s.HashPassword(randomPassword)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
grantPlan := s.resolveSignupGrantPlan(ctx, providerType)
|
||||||
|
var defaultRPMLimit int
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||||
|
}
|
||||||
|
user := &User{
|
||||||
|
Email: email,
|
||||||
|
Username: strings.TrimSpace(username),
|
||||||
|
PasswordHash: hashedPassword,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: grantPlan.Balance,
|
||||||
|
Concurrency: grantPlan.Concurrency,
|
||||||
|
RPMLimit: defaultRPMLimit,
|
||||||
|
Status: StatusActive,
|
||||||
|
SignupSource: providerType,
|
||||||
|
}
|
||||||
|
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||||
|
if errors.Is(err, ErrEmailExists) {
|
||||||
|
existing, loadErr := s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if loadErr != nil {
|
||||||
|
return nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
return nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
s.postAuthUserBootstrap(ctx, user, providerType, false)
|
||||||
|
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||||
|
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||||
|
if invitationRedeemCode != nil {
|
||||||
|
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||||
|
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, invitationCode)
|
||||||
|
return nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) findEmailOAuthIdentityOwner(ctx context.Context, providerType, providerKey, providerSubject string) (*User, error) {
|
||||||
|
identity, err := s.entClient.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
|
authidentity.ProviderKeyEQ(providerKey),
|
||||||
|
authidentity.ProviderSubjectEQ(providerSubject),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
user, err := s.userRepo.GetByID(ctx, identity.UserID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) ensureEmailOAuthIdentity(ctx context.Context, userID int64, input EmailOAuthIdentityInput) error {
|
||||||
|
metadata := map[string]any{
|
||||||
|
"email": strings.TrimSpace(strings.ToLower(input.Email)),
|
||||||
|
"email_verified": input.EmailVerified,
|
||||||
|
}
|
||||||
|
for key, value := range input.UpstreamMetadata {
|
||||||
|
metadata[key] = value
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(input.Username) != "" {
|
||||||
|
metadata["username"] = strings.TrimSpace(input.Username)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(input.DisplayName) != "" {
|
||||||
|
metadata["display_name"] = strings.TrimSpace(input.DisplayName)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(input.AvatarURL) != "" {
|
||||||
|
metadata["avatar_url"] = strings.TrimSpace(input.AvatarURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||||
|
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||||
|
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||||
|
identity, err := s.entClient.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
|
authidentity.ProviderKeyEQ(providerKey),
|
||||||
|
authidentity.ProviderSubjectEQ(providerSubject),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
if identity != nil {
|
||||||
|
if identity.UserID != userID {
|
||||||
|
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||||
|
}
|
||||||
|
_, err = s.entClient.AuthIdentity.UpdateOneID(identity.ID).
|
||||||
|
SetMetadata(metadata).
|
||||||
|
Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.entClient.AuthIdentity.Create().
|
||||||
|
SetUserID(userID).
|
||||||
|
SetProviderType(providerType).
|
||||||
|
SetProviderKey(providerKey).
|
||||||
|
SetProviderSubject(providerSubject).
|
||||||
|
SetMetadata(metadata).
|
||||||
|
Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func normalizeOAuthSignupSource(signupSource string) string {
|
func normalizeOAuthSignupSource(signupSource string) string {
|
||||||
@ -17,7 +18,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
|
|||||||
switch signupSource {
|
switch signupSource {
|
||||||
case "", "email":
|
case "", "email":
|
||||||
return "email"
|
return "email"
|
||||||
case "linuxdo", "wechat", "oidc":
|
case "linuxdo", "wechat", "oidc", "github", "google":
|
||||||
return signupSource
|
return signupSource
|
||||||
default:
|
default:
|
||||||
return "email"
|
return "email"
|
||||||
@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
|||||||
return tokenPair, user, nil
|
return tokenPair, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth
|
||||||
|
// provider that has already returned a verified email address.
|
||||||
|
func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
|
||||||
|
ctx context.Context,
|
||||||
|
email string,
|
||||||
|
password string,
|
||||||
|
invitationCode string,
|
||||||
|
signupSource string,
|
||||||
|
) (*TokenPair, *User, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
|
return nil, nil, ErrRegDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
email = strings.TrimSpace(strings.ToLower(email))
|
||||||
|
if email == "" || len(email) > 255 {
|
||||||
|
return nil, nil, ErrEmailVerifyRequired
|
||||||
|
}
|
||||||
|
if _, err := mail.ParseAddress(email); err != nil {
|
||||||
|
return nil, nil, ErrEmailVerifyRequired
|
||||||
|
}
|
||||||
|
if isReservedEmail(email) {
|
||||||
|
return nil, nil, ErrEmailReserved
|
||||||
|
}
|
||||||
|
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(password) == "" {
|
||||||
|
return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
|
||||||
|
}
|
||||||
|
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
if existsEmail {
|
||||||
|
return nil, nil, ErrEmailExists
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := s.HashPassword(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signupSource = normalizeOAuthSignupSource(signupSource)
|
||||||
|
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||||
|
var defaultRPMLimit int
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||||
|
}
|
||||||
|
user := &User{
|
||||||
|
Email: email,
|
||||||
|
PasswordHash: hashedPassword,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: grantPlan.Balance,
|
||||||
|
Concurrency: grantPlan.Concurrency,
|
||||||
|
RPMLimit: defaultRPMLimit,
|
||||||
|
Status: StatusActive,
|
||||||
|
SignupSource: signupSource,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||||
|
if errors.Is(err, ErrEmailExists) {
|
||||||
|
return nil, nil, ErrEmailExists
|
||||||
|
}
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||||
|
if err != nil {
|
||||||
|
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
|
||||||
|
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||||
|
}
|
||||||
|
return tokenPair, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
|
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
|
||||||
// only after the pending OAuth flow has fully reached its last reversible step.
|
// only after the pending OAuth flow has fully reached its last reversible step.
|
||||||
func (s *AuthService) FinalizeOAuthEmailAccount(
|
func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||||
|
|||||||
@ -229,6 +229,67 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes
|
|||||||
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
|
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
email string
|
||||||
|
signupSource string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "github",
|
||||||
|
email: "github@example.com",
|
||||||
|
signupSource: " GitHub ",
|
||||||
|
want: "github",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "google",
|
||||||
|
email: "google@example.com",
|
||||||
|
signupSource: " Google ",
|
||||||
|
want: "google",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
userRepo := &userRepoStub{nextID: 43}
|
||||||
|
emailCache := &emailCacheStub{
|
||||||
|
data: &VerificationCodeData{
|
||||||
|
Code: "246810",
|
||||||
|
Attempts: 0,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authService := newOAuthEmailFlowAuthService(
|
||||||
|
userRepo,
|
||||||
|
&redeemCodeRepoStub{},
|
||||||
|
&refreshTokenCacheStub{},
|
||||||
|
map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
|
},
|
||||||
|
emailCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||||
|
context.Background(),
|
||||||
|
tt.email,
|
||||||
|
"secret-123",
|
||||||
|
"246810",
|
||||||
|
"",
|
||||||
|
tt.signupSource,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, tokenPair)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Len(t, userRepo.created, 1)
|
||||||
|
require.Equal(t, tt.want, userRepo.created[0].SignupSource)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
|
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
|
||||||
userRepo := &userRepoStub{nextID: 43}
|
userRepo := &userRepoStub{nextID: 43}
|
||||||
emailCache := &emailCacheStub{
|
emailCache := &emailCacheStub{
|
||||||
@ -256,7 +317,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing
|
|||||||
"secret-123",
|
"secret-123",
|
||||||
"246810",
|
"246810",
|
||||||
"",
|
"",
|
||||||
"github",
|
"unknown-provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@ -775,6 +775,10 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
|||||||
return defaults.OIDC, true
|
return defaults.OIDC, true
|
||||||
case "wechat":
|
case "wechat":
|
||||||
return defaults.WeChat, true
|
return defaults.WeChat, true
|
||||||
|
case "github":
|
||||||
|
return defaults.GitHub, true
|
||||||
|
case "google":
|
||||||
|
return defaults.Google, true
|
||||||
default:
|
default:
|
||||||
return ProviderDefaultGrantSettings{}, false
|
return ProviderDefaultGrantSettings{}, false
|
||||||
}
|
}
|
||||||
|
|||||||
@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
|
|||||||
return ok, nil
|
return ok, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|||||||
64
backend/internal/service/codex_image_generation_bridge.go
Normal file
64
backend/internal/service/codex_image_generation_bridge.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
const featureKeyCodexImageGenerationBridge = "codex_image_generation_bridge"
|
||||||
|
|
||||||
|
func boolOverridePtr(v bool) *bool {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolOverrideFromMap(values map[string]any, keys ...string) *bool {
|
||||||
|
if values == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
if v, ok := values[key].(bool); ok {
|
||||||
|
return boolOverridePtr(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func platformBoolOverride(values map[string]any, key string, platform string) *bool {
|
||||||
|
if values == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if v, ok := values[key].(bool); ok {
|
||||||
|
return boolOverridePtr(v)
|
||||||
|
}
|
||||||
|
raw, ok := values[key].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
platform = strings.TrimSpace(platform)
|
||||||
|
if platform == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if v, ok := raw[platform].(bool); ok {
|
||||||
|
return boolOverridePtr(v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodexImageGenerationBridgeOverride returns the channel-level override for Codex
|
||||||
|
// image_generation bridge injection. Nil means follow the global/account policy.
|
||||||
|
func (c *Channel) CodexImageGenerationBridgeOverride(platform string) *bool {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return platformBoolOverride(c.FeaturesConfig, featureKeyCodexImageGenerationBridge, platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodexImageGenerationBridgeOverride returns the account-level override for Codex
|
||||||
|
// image_generation bridge injection. Nil means follow the channel/global policy.
|
||||||
|
func (a *Account) CodexImageGenerationBridgeOverride() *bool {
|
||||||
|
if a == nil || a.Platform != PlatformOpenAI || a.Extra == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if override := boolOverrideFromMap(a.Extra, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled"); override != nil {
|
||||||
|
return override
|
||||||
|
}
|
||||||
|
openaiConfig, _ := a.Extra[PlatformOpenAI].(map[string]any)
|
||||||
|
return boolOverrideFromMap(openaiConfig, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled")
|
||||||
|
}
|
||||||
2048
backend/internal/service/content_moderation.go
Normal file
2048
backend/internal/service/content_moderation.go
Normal file
File diff suppressed because it is too large
Load Diff
117
backend/internal/service/content_moderation_email.go
Normal file
117
backend/internal/service/content_moderation_email.go
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
|
||||||
|
if log == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
userName := strings.TrimSpace(log.UserEmail)
|
||||||
|
if userName == "" && log.UserID != nil {
|
||||||
|
userName = fmt.Sprintf("UID %d", *log.UserID)
|
||||||
|
}
|
||||||
|
threshold := cfg.BanThreshold
|
||||||
|
if threshold <= 0 {
|
||||||
|
threshold = defaultContentModerationBanThreshold
|
||||||
|
}
|
||||||
|
statusBlock := ""
|
||||||
|
if log.AutoBanned {
|
||||||
|
statusBlock = `<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>`
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(`<!doctype html>
|
||||||
|
<html>
|
||||||
|
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
|
||||||
|
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
|
||||||
|
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
|
||||||
|
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
|
||||||
|
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 风控提醒</div>
|
||||||
|
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户触发内容审计规则</h1>
|
||||||
|
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>,您的 API 请求在内容审计中触发平台风控策略。详情如下。</p>
|
||||||
|
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
|
||||||
|
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">触发详情</h2>
|
||||||
|
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 次(阈值 %d)</td></tr>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
%s
|
||||||
|
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送,请勿回复。</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`,
|
||||||
|
html.EscapeString(userName),
|
||||||
|
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
|
||||||
|
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
|
||||||
|
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
|
||||||
|
log.HighestScore,
|
||||||
|
log.ViolationCount,
|
||||||
|
threshold,
|
||||||
|
statusBlock,
|
||||||
|
html.EscapeString(siteName),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string {
|
||||||
|
if log == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
userName := strings.TrimSpace(log.UserEmail)
|
||||||
|
if userName == "" && log.UserID != nil {
|
||||||
|
userName = fmt.Sprintf("UID %d", *log.UserID)
|
||||||
|
}
|
||||||
|
threshold := cfg.BanThreshold
|
||||||
|
if threshold <= 0 {
|
||||||
|
threshold = defaultContentModerationBanThreshold
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(`<!doctype html>
|
||||||
|
<html>
|
||||||
|
<body style="margin:0;padding:0;background:#f5f6fb;color:#222;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Arial,sans-serif;">
|
||||||
|
<div style="max-width:680px;margin:0 auto;padding:32px 20px;">
|
||||||
|
<div style="height:8px;background:#ef4444;border-radius:14px 14px 0 0;"></div>
|
||||||
|
<div style="background:#fff;border-radius:0 0 14px 14px;padding:40px 48px;box-shadow:0 8px 28px rgba(15,23,42,.08);">
|
||||||
|
<div style="letter-spacing:4px;color:#999;font-size:14px;text-transform:uppercase;">Risk Control / 账户封禁</div>
|
||||||
|
<h1 style="margin:20px 0 28px;font-size:30px;line-height:1.25;">账户已被自动禁用</h1>
|
||||||
|
<p style="font-size:17px;line-height:1.9;margin:0 0 24px;">尊敬的用户 <strong>%s</strong>,您的账户在计数周期内多次触发平台风控策略,系统已自动禁用该账户。详情如下。</p>
|
||||||
|
<div style="background:#fff1f2;border:1px solid #fecdd3;border-radius:12px;padding:22px 28px;margin:28px 0;">
|
||||||
|
<h2 style="margin:0 0 18px;color:#b91c1c;font-size:18px;">封禁详情</h2>
|
||||||
|
<table style="width:100%%;border-collapse:collapse;font-size:16px;">
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">封禁时间</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">触发来源</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">内容审核</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">所属分组</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;border-bottom:1px solid #fee2e2;">命中类别</td><td style="padding:12px 0;border-bottom:1px solid #fee2e2;">%s / %.3f</td></tr>
|
||||||
|
<tr><td style="padding:12px 0;color:#888;">累计触发次数</td><td style="padding:12px 0;color:#dc2626;font-weight:700;">%d 次(阈值 %d)</td></tr>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
<div style="margin-top:24px;padding:18px 20px;border-radius:10px;background:#ff3b30;color:#fff;font-size:18px;font-weight:700;text-align:center;line-height:1.6;">账户当前处于封禁状态,所有 API 请求将被拒绝</div>
|
||||||
|
<p style="font-size:15px;line-height:1.8;color:#666;margin-top:24px;">如需申诉或恢复账号,请联系平台管理员处理。</p>
|
||||||
|
<p style="font-size:14px;line-height:1.8;color:#777;margin-top:28px;">此邮件由 %s 自动发送,请勿回复。</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`,
|
||||||
|
html.EscapeString(userName),
|
||||||
|
html.EscapeString(time.Now().Format("2006-01-02 15:04:05")),
|
||||||
|
html.EscapeString(defaultContentModerationString(log.GroupName, "-")),
|
||||||
|
html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")),
|
||||||
|
log.HighestScore,
|
||||||
|
log.ViolationCount,
|
||||||
|
threshold,
|
||||||
|
html.EscapeString(siteName),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultContentModerationString(value string, fallback string) string {
|
||||||
|
if strings.TrimSpace(value) == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(value)
|
||||||
|
}
|
||||||
320
backend/internal/service/content_moderation_input.go
Normal file
320
backend/internal/service/content_moderation_input.go
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExtractContentModerationText(protocol string, body []byte) string {
|
||||||
|
return ExtractContentModerationInput(protocol, body).Text
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return ContentModerationInput{}
|
||||||
|
}
|
||||||
|
var parts []string
|
||||||
|
var images []string
|
||||||
|
switch protocol {
|
||||||
|
case ContentModerationProtocolAnthropicMessages:
|
||||||
|
collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images)
|
||||||
|
case ContentModerationProtocolOpenAIChat:
|
||||||
|
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
|
||||||
|
case ContentModerationProtocolOpenAIResponses:
|
||||||
|
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
|
||||||
|
case ContentModerationProtocolGemini:
|
||||||
|
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
|
||||||
|
case ContentModerationProtocolOpenAIImages:
|
||||||
|
addModerationText(&parts, gjson.GetBytes(body, "prompt").String())
|
||||||
|
collectContentValue(gjson.GetBytes(body, "images"), &parts, &images)
|
||||||
|
default:
|
||||||
|
collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images)
|
||||||
|
collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images)
|
||||||
|
collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images)
|
||||||
|
}
|
||||||
|
out := ContentModerationInput{
|
||||||
|
Text: normalizeContentModerationText(strings.Join(parts, "\n")),
|
||||||
|
Images: normalizeModerationImages(images),
|
||||||
|
}
|
||||||
|
out.Normalize()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) {
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var lastParts []string
|
||||||
|
var lastImages []string
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
|
||||||
|
var candidate []string
|
||||||
|
var candidateImages []string
|
||||||
|
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
|
||||||
|
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||||
|
lastParts = candidate
|
||||||
|
lastImages = candidateImages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
*parts = append(*parts, lastParts...)
|
||||||
|
*images = append(*images, lastImages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var lastParts []string
|
||||||
|
var lastImages []string
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
|
||||||
|
var candidate []string
|
||||||
|
var candidateImages []string
|
||||||
|
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
|
||||||
|
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||||
|
lastParts = candidate
|
||||||
|
lastImages = candidateImages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
*parts = append(*parts, lastParts...)
|
||||||
|
*images = append(*images, lastImages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
|
||||||
|
switch {
|
||||||
|
case !value.Exists():
|
||||||
|
return
|
||||||
|
case value.Type == gjson.String:
|
||||||
|
if !isAnthropicSystemReminderText(value.String()) {
|
||||||
|
addModerationText(parts, value.String())
|
||||||
|
}
|
||||||
|
case value.IsArray():
|
||||||
|
value.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
collectAnthropicUserContentValue(item, parts, images)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
case value.IsObject():
|
||||||
|
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
|
||||||
|
switch typ {
|
||||||
|
case "", "text", "input_text", "message":
|
||||||
|
if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) {
|
||||||
|
addModerationText(parts, value.Get("text").String())
|
||||||
|
}
|
||||||
|
if value.Get("content").Exists() {
|
||||||
|
collectAnthropicUserContentValue(value.Get("content"), parts, images)
|
||||||
|
}
|
||||||
|
case "image_url", "input_image", "image":
|
||||||
|
collectContentValue(value, parts, images)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAnthropicSystemReminderText(text string) bool {
|
||||||
|
return strings.HasPrefix(strings.TrimSpace(text), "<system-reminder>")
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) {
|
||||||
|
switch {
|
||||||
|
case !input.Exists():
|
||||||
|
return
|
||||||
|
case input.Type == gjson.String:
|
||||||
|
addModerationText(parts, input.String())
|
||||||
|
case input.IsArray():
|
||||||
|
var last gjson.Result
|
||||||
|
input.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if isResponsesUserTextItem(item) {
|
||||||
|
last = item
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if last.Exists() {
|
||||||
|
collectContentValue(last.Get("content"), parts, images)
|
||||||
|
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
|
||||||
|
collectContentValue(last, parts, images)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case input.IsObject():
|
||||||
|
if isResponsesUserTextItem(input) {
|
||||||
|
collectContentValue(input.Get("content"), parts, images)
|
||||||
|
if input.Get("type").String() == "input_text" || input.Get("text").Exists() {
|
||||||
|
collectContentValue(input, parts, images)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isResponsesUserTextItem(item gjson.Result) bool {
|
||||||
|
role := strings.ToLower(strings.TrimSpace(item.Get("role").String()))
|
||||||
|
if role == "user" {
|
||||||
|
return responseItemHasModerationText(item)
|
||||||
|
}
|
||||||
|
if role != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return responseItemHasModerationText(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseItemHasModerationText(item gjson.Result) bool {
|
||||||
|
var parts []string
|
||||||
|
var images []string
|
||||||
|
collectContentValue(item.Get("content"), &parts, &images)
|
||||||
|
if item.Get("type").String() == "input_text" || item.Get("text").Exists() {
|
||||||
|
collectContentValue(item, &parts, &images)
|
||||||
|
}
|
||||||
|
return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) {
|
||||||
|
if !contents.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var lastParts []string
|
||||||
|
var lastImages []string
|
||||||
|
contents.ForEach(func(_, content gjson.Result) bool {
|
||||||
|
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
|
||||||
|
if role == "" || role == "user" {
|
||||||
|
var candidate []string
|
||||||
|
var candidateImages []string
|
||||||
|
if arr := content.Get("parts"); arr.IsArray() {
|
||||||
|
arr.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
addModerationText(&candidate, part.Get("text").String())
|
||||||
|
addGeminiModerationImage(&candidateImages, part)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||||
|
lastParts = candidate
|
||||||
|
lastImages = candidateImages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
*parts = append(*parts, lastParts...)
|
||||||
|
*images = append(*images, lastImages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {
|
||||||
|
switch {
|
||||||
|
case !value.Exists():
|
||||||
|
return
|
||||||
|
case value.Type == gjson.String:
|
||||||
|
addModerationText(parts, value.String())
|
||||||
|
case value.IsArray():
|
||||||
|
value.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
collectContentValue(item, parts, images)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
case value.IsObject():
|
||||||
|
typ := strings.ToLower(strings.TrimSpace(value.Get("type").String()))
|
||||||
|
addModerationImage(images, value.Get("image_url.url").String())
|
||||||
|
addModerationImage(images, value.Get("image_url").String())
|
||||||
|
addModerationImage(images, value.Get("url").String())
|
||||||
|
addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String())
|
||||||
|
addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String())
|
||||||
|
addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String())
|
||||||
|
addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String())
|
||||||
|
addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String())
|
||||||
|
addModerationImage(images, value.Get("source.data").String())
|
||||||
|
addModerationImage(images, value.Get("data").String())
|
||||||
|
addModerationImage(images, value.Get("base64").String())
|
||||||
|
switch typ {
|
||||||
|
case "", "text", "input_text", "message":
|
||||||
|
if value.Get("text").Exists() {
|
||||||
|
addModerationText(parts, value.Get("text").String())
|
||||||
|
}
|
||||||
|
if value.Get("content").Exists() {
|
||||||
|
collectContentValue(value.Get("content"), parts, images)
|
||||||
|
}
|
||||||
|
case "image_url", "input_image", "image":
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addGeminiModerationImage(images *[]string, part gjson.Result) {
|
||||||
|
if inlineData := part.Get("inline_data"); inlineData.IsObject() {
|
||||||
|
mimeType := strings.TrimSpace(inlineData.Get("mime_type").String())
|
||||||
|
data := strings.TrimSpace(inlineData.Get("data").String())
|
||||||
|
if mimeType != "" && data != "" {
|
||||||
|
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if inlineData := part.Get("inlineData"); inlineData.IsObject() {
|
||||||
|
mimeType := strings.TrimSpace(inlineData.Get("mimeType").String())
|
||||||
|
data := strings.TrimSpace(inlineData.Get("data").String())
|
||||||
|
if mimeType != "" && data != "" {
|
||||||
|
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
addModerationImage(images, part.Get("file_data.file_uri").String())
|
||||||
|
addModerationImage(images, part.Get("fileData.fileUri").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func addModerationImageData(images *[]string, mimeType string, data string) {
|
||||||
|
mimeType = strings.TrimSpace(mimeType)
|
||||||
|
data = strings.TrimSpace(data)
|
||||||
|
if mimeType == "" || data == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func addModerationImage(images *[]string, image string) {
|
||||||
|
image = strings.TrimSpace(image)
|
||||||
|
if image == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") {
|
||||||
|
*images = append(*images, image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeModerationImages(images []string) []string {
|
||||||
|
out := make([]string, 0, len(images))
|
||||||
|
seen := make(map[string]struct{}, len(images))
|
||||||
|
for _, image := range images {
|
||||||
|
image = strings.TrimSpace(image)
|
||||||
|
if image == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[image]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[image] = struct{}{}
|
||||||
|
out = append(out, image)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func limitContentModerationImages(images []string) []string {
|
||||||
|
if len(images) <= maxContentModerationInputImages {
|
||||||
|
return images
|
||||||
|
}
|
||||||
|
idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(images))))
|
||||||
|
if err != nil {
|
||||||
|
return images[:maxContentModerationInputImages]
|
||||||
|
}
|
||||||
|
return []string{images[int(idx.Int64())]}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addModerationText(parts *[]string, text string) {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "<system-reminder>") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*parts = append(*parts, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeContentModerationText(text string) string {
|
||||||
|
return strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
|
||||||
|
}
|
||||||
37
backend/internal/service/content_moderation_redact.go
Normal file
37
backend/internal/service/content_moderation_redact.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var contentModerationSecretPatterns = []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)\bhttps?://[^\s"'<>,。;、]+`),
|
||||||
|
regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`),
|
||||||
|
regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`),
|
||||||
|
regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`),
|
||||||
|
regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`),
|
||||||
|
regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`),
|
||||||
|
regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`),
|
||||||
|
regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`),
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactContentModerationSecrets(text string) string {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
out := text
|
||||||
|
for idx, pattern := range contentModerationSecretPatterns {
|
||||||
|
switch idx {
|
||||||
|
case 1:
|
||||||
|
out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`)
|
||||||
|
case 2:
|
||||||
|
out = pattern.ReplaceAllString(out, `${1}[已脱敏]`)
|
||||||
|
default:
|
||||||
|
out = pattern.ReplaceAllString(out, `[已脱敏]`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
1006
backend/internal/service/content_moderation_test.go
Normal file
1006
backend/internal/service/content_moderation_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -108,6 +108,12 @@ const (
|
|||||||
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
||||||
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
||||||
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
||||||
|
SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路
|
||||||
|
SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON)
|
||||||
|
SettingKeyLoginAgreementEnabled = "login_agreement_enabled" // 登录前是否要求同意条款
|
||||||
|
SettingKeyLoginAgreementMode = "login_agreement_mode" // 条款确认展示模式:modal / checkbox
|
||||||
|
SettingKeyLoginAgreementUpdatedAt = "login_agreement_updated_at" // 条款更新日期(展示用)
|
||||||
|
SettingKeyLoginAgreementDocuments = "login_agreement_documents" // 条款文档列表(JSON,Markdown 内容)
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||||
@ -174,6 +180,18 @@ const (
|
|||||||
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
|
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
|
||||||
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
|
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
|
||||||
|
|
||||||
|
// GitHub / Google 邮箱快捷登录设置
|
||||||
|
SettingKeyGitHubOAuthEnabled = "github_oauth_enabled"
|
||||||
|
SettingKeyGitHubOAuthClientID = "github_oauth_client_id"
|
||||||
|
SettingKeyGitHubOAuthClientSecret = "github_oauth_client_secret"
|
||||||
|
SettingKeyGitHubOAuthRedirectURL = "github_oauth_redirect_url"
|
||||||
|
SettingKeyGitHubOAuthFrontendRedirectURL = "github_oauth_frontend_redirect_url"
|
||||||
|
SettingKeyGoogleOAuthEnabled = "google_oauth_enabled"
|
||||||
|
SettingKeyGoogleOAuthClientID = "google_oauth_client_id"
|
||||||
|
SettingKeyGoogleOAuthClientSecret = "google_oauth_client_secret"
|
||||||
|
SettingKeyGoogleOAuthRedirectURL = "google_oauth_redirect_url"
|
||||||
|
SettingKeyGoogleOAuthFrontendRedirectURL = "google_oauth_frontend_redirect_url"
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SettingKeySiteName = "site_name" // 网站名称
|
SettingKeySiteName = "site_name" // 网站名称
|
||||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||||
@ -217,6 +235,16 @@ const (
|
|||||||
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
|
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
|
||||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
|
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
|
||||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
|
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
|
||||||
|
SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
|
||||||
|
SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
|
||||||
|
SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
|
||||||
|
SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
|
||||||
|
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
|
||||||
|
SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
|
||||||
|
SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
|
||||||
|
SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
|
||||||
|
SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
|
||||||
|
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
|
||||||
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
|
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
|
||||||
|
|
||||||
// 管理员 API Key
|
// 管理员 API Key
|
||||||
|
|||||||
@ -124,6 +124,24 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
|||||||
require.Contains(t, got, "fast-mode-2026-02-01")
|
require.Contains(t, got, "fast-mode-2026-02-01")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFullClaudeCodeMimicryBetas_DoesNotDefaultRedactThinking(t *testing.T) {
|
||||||
|
required := claude.FullClaudeCodeMimicryBetas()
|
||||||
|
|
||||||
|
require.NotContains(t, required, claude.BetaRedactThinking)
|
||||||
|
require.Contains(t, required, claude.BetaClaudeCode)
|
||||||
|
require.Contains(t, required, claude.BetaOAuth)
|
||||||
|
require.Contains(t, required, claude.BetaInterleavedThinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeAnthropicBetaDropping_PreservesIncomingRedactThinking(t *testing.T) {
|
||||||
|
required := claude.FullClaudeCodeMimicryBetas()
|
||||||
|
incoming := claude.BetaRedactThinking
|
||||||
|
|
||||||
|
got := mergeAnthropicBetaDropping(required, incoming, droppedBetaSet())
|
||||||
|
|
||||||
|
require.Contains(t, got, claude.BetaRedactThinking)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDroppedBetaSet(t *testing.T) {
|
func TestDroppedBetaSet(t *testing.T) {
|
||||||
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
|
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
|
||||||
base := droppedBetaSet()
|
base := droppedBetaSet()
|
||||||
|
|||||||
@ -5445,6 +5445,12 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
if !sawTerminalEvent {
|
if !sawTerminalEvent {
|
||||||
|
if clientDisconnected && streamInterval > 0 {
|
||||||
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
if time.Since(lastRead) >= streamInterval {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||||
}
|
}
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||||
|
|||||||
@ -440,6 +440,21 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont
|
|||||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) isCodexImageGenerationBridgeEnabled(ctx context.Context, account *Account, apiKey *APIKey) bool {
|
||||||
|
if override := account.CodexImageGenerationBridgeOverride(); override != nil {
|
||||||
|
return *override
|
||||||
|
}
|
||||||
|
if s != nil && s.channelService != nil && apiKey != nil && apiKey.GroupID != nil {
|
||||||
|
ch, err := s.channelService.GetChannelForGroup(ctx, *apiKey.GroupID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to resolve codex image generation bridge channel override", "group_id", *apiKey.GroupID, "error", err)
|
||||||
|
} else if override := ch.CodexImageGenerationBridgeOverride(PlatformOpenAI); override != nil {
|
||||||
|
return *override
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s != nil && s.cfg != nil && s.cfg.Gateway.CodexImageGenerationBridgeEnabled
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||||
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||||
return false
|
return false
|
||||||
@ -2059,6 +2074,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
if apiKey != nil {
|
if apiKey != nil {
|
||||||
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
|
imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
|
||||||
}
|
}
|
||||||
|
codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey)
|
||||||
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
|
||||||
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
@ -2128,7 +2144,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||||
}
|
}
|
||||||
|
|
||||||
if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
if codexImageGenerationBridgeEnabled && ensureOpenAIResponsesImageGenerationTool(reqBody) {
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
disablePatch()
|
disablePatch()
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
|
||||||
@ -2139,7 +2155,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
disablePatch()
|
disablePatch()
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||||
}
|
}
|
||||||
if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
if codexImageGenerationBridgeEnabled && applyCodexImageGenerationBridgeInstructions(reqBody) {
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
disablePatch()
|
disablePatch()
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
|
||||||
|
|||||||
@ -83,12 +83,14 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
allowImages bool
|
allowImages bool
|
||||||
wantInjected bool
|
bridgeEnabled bool
|
||||||
|
wantInjected bool
|
||||||
}{
|
}{
|
||||||
{name: "disabled group skips injection", allowImages: false, wantInjected: false},
|
{name: "disabled group skips injection", allowImages: false, bridgeEnabled: true, wantInjected: false},
|
||||||
{name: "enabled group injects image tool", allowImages: true, wantInjected: true},
|
{name: "enabled group skips injection by default", allowImages: true, bridgeEnabled: false, wantInjected: false},
|
||||||
|
{name: "enabled group injects image tool when bridge enabled", allowImages: true, bridgeEnabled: true, wantInjected: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -101,6 +103,7 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
svc := newOpenAIImageGenerationControlTestService(upstream)
|
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||||
|
svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.bridgeEnabled
|
||||||
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
|
c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
|
||||||
account := newOpenAIImageGenerationControlTestAccount()
|
account := newOpenAIImageGenerationControlTestAccount()
|
||||||
|
|
||||||
@ -117,6 +120,154 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForward_ExplicitImageToolWorksWithBridgeDisabled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"id":"resp_explicit_image","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||||
|
c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0")
|
||||||
|
account := newOpenAIImageGenerationControlTestAccount()
|
||||||
|
body := []byte(`{"model":"gpt-5.4","input":"draw","stream":false,"tools":[{"type":"image_generation","format":"jpeg"}]}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists())
|
||||||
|
require.Equal(t, "jpeg", gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").output_format`).String())
|
||||||
|
require.False(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").format`).Exists())
|
||||||
|
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
|
||||||
|
require.NotContains(t, instructions, "image_generation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForward_ChannelBridgeOverrideEnablesCodexInjection(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"id":"resp_channel_bridge","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := newOpenAIImageGenerationControlTestService(upstream)
|
||||||
|
groupID := int64(4242)
|
||||||
|
svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, &Channel{
|
||||||
|
ID: 9001,
|
||||||
|
Status: StatusActive,
|
||||||
|
FeaturesConfig: map[string]any{
|
||||||
|
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0")
|
||||||
|
account := newOpenAIImageGenerationControlTestAccount()
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists())
|
||||||
|
instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
|
||||||
|
require.Contains(t, instructions, "image_generation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_CodexImageGenerationBridgeOverridePrecedence(t *testing.T) {
|
||||||
|
groupID := int64(4242)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
global bool
|
||||||
|
channel *Channel
|
||||||
|
account *Account
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "global default enables bridge",
|
||||||
|
global: true,
|
||||||
|
account: &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "channel true overrides disabled global",
|
||||||
|
global: false,
|
||||||
|
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||||
|
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
|
||||||
|
}},
|
||||||
|
account: &Account{Platform: PlatformOpenAI},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "channel false overrides enabled global",
|
||||||
|
global: true,
|
||||||
|
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||||
|
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false},
|
||||||
|
}},
|
||||||
|
account: &Account{Platform: PlatformOpenAI},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "account false overrides channel and global true",
|
||||||
|
global: true,
|
||||||
|
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||||
|
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true},
|
||||||
|
}},
|
||||||
|
account: &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Extra: map[string]any{featureKeyCodexImageGenerationBridge: false},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested account true overrides channel false",
|
||||||
|
global: false,
|
||||||
|
channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{
|
||||||
|
featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false},
|
||||||
|
}},
|
||||||
|
account: &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Extra: map[string]any{
|
||||||
|
PlatformOpenAI: map[string]any{"codex_image_generation_bridge_enabled": true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non openai account extra is ignored",
|
||||||
|
global: false,
|
||||||
|
account: &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Extra: map[string]any{featureKeyCodexImageGenerationBridge: true},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
|
||||||
|
svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.global
|
||||||
|
if tt.channel != nil {
|
||||||
|
svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, tt.channel)
|
||||||
|
}
|
||||||
|
apiKey := &APIKey{GroupID: &groupID}
|
||||||
|
|
||||||
|
got := svc.isCodexImageGenerationBridgeEnabled(context.Background(), tt.account, apiKey)
|
||||||
|
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
|
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@ -180,6 +331,18 @@ func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newOpenAIImageGenerationControlChannelService(groupID int64, ch *Channel) *ChannelService {
|
||||||
|
svc := &ChannelService{}
|
||||||
|
cache := newEmptyChannelCache()
|
||||||
|
if ch != nil {
|
||||||
|
cache.channelByGroupID[groupID] = ch
|
||||||
|
cache.byID[ch.ID] = ch
|
||||||
|
}
|
||||||
|
cache.loadedAt = time.Now()
|
||||||
|
svc.cache.Store(cache)
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
|
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|||||||
@ -90,6 +90,69 @@ type OpenAIImagesRequest struct {
|
|||||||
bodyHash string
|
bodyHash string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIImagesRequest) ModerationBody() []byte {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := map[string]any{}
|
||||||
|
if prompt := strings.TrimSpace(r.Prompt); prompt != "" {
|
||||||
|
payload["prompt"] = prompt
|
||||||
|
}
|
||||||
|
images := r.moderationImages()
|
||||||
|
if len(images) > 0 {
|
||||||
|
payload["images"] = images
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIImagesRequest) moderationImages() []map[string]string {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1)
|
||||||
|
for _, imageURL := range r.InputImageURLs {
|
||||||
|
imageURL = strings.TrimSpace(imageURL)
|
||||||
|
if imageURL != "" {
|
||||||
|
images = append(images, map[string]string{"image_url": imageURL})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, upload := range r.Uploads {
|
||||||
|
if dataURL := upload.ModerationDataURL(); dataURL != "" {
|
||||||
|
images = append(images, map[string]string{"image_url": dataURL})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" {
|
||||||
|
images = append(images, map[string]string{"image_url": maskURL})
|
||||||
|
}
|
||||||
|
if r.MaskUpload != nil {
|
||||||
|
if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" {
|
||||||
|
images = append(images, map[string]string{"image_url": dataURL})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return images
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u OpenAIImagesUpload) ModerationDataURL() string {
|
||||||
|
if len(u.Data) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
contentType := strings.TrimSpace(u.ContentType)
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = http.DetectContentType(u.Data)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.ToLower(contentType), "image/") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data))
|
||||||
|
}
|
||||||
|
|
||||||
func (r *OpenAIImagesRequest) IsEdits() bool {
|
func (r *OpenAIImagesRequest) IsEdits() bool {
|
||||||
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
|
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
|
||||||
}
|
}
|
||||||
|
|||||||
@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
|
|||||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) {
|
||||||
|
parsed := &OpenAIImagesRequest{
|
||||||
|
Endpoint: openAIImagesEditsEndpoint,
|
||||||
|
Prompt: "replace background",
|
||||||
|
InputImageURLs: []string{"https://example.com/source.png"},
|
||||||
|
MaskImageURL: "https://example.com/mask.png",
|
||||||
|
}
|
||||||
|
|
||||||
|
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
|
||||||
|
|
||||||
|
require.Equal(t, "replace background", input.Text)
|
||||||
|
require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) {
|
||||||
|
parsed := &OpenAIImagesRequest{
|
||||||
|
Endpoint: openAIImagesEditsEndpoint,
|
||||||
|
Prompt: "replace background",
|
||||||
|
Uploads: []OpenAIImagesUpload{{
|
||||||
|
FieldName: "image",
|
||||||
|
FileName: "source.png",
|
||||||
|
ContentType: "image/png",
|
||||||
|
Data: []byte("fake-image-bytes"),
|
||||||
|
}},
|
||||||
|
MaskUpload: &OpenAIImagesUpload{
|
||||||
|
FieldName: "mask",
|
||||||
|
FileName: "mask.png",
|
||||||
|
ContentType: "image/png",
|
||||||
|
Data: []byte("fake-mask-bytes"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody())
|
||||||
|
|
||||||
|
require.Equal(t, "replace background", input.Text)
|
||||||
|
require.Equal(t, []string{
|
||||||
|
"data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==",
|
||||||
|
"data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz",
|
||||||
|
}, input.Images)
|
||||||
|
|
||||||
|
log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "")
|
||||||
|
require.Equal(t, "replace background", log.InputExcerpt)
|
||||||
|
require.NotContains(t, log.InputExcerpt, "ZmFrZS")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -52,3 +52,47 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess
|
|||||||
require.Equal(t, "client-id-1", info.ClientID)
|
require.Equal(t, "client-id-1", info.ClientID)
|
||||||
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
|
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) {
|
||||||
|
refresher := NewOpenAITokenRefresher(nil, nil)
|
||||||
|
expiresAt := time.Now().Add(time.Minute).UTC().Format(time.RFC3339)
|
||||||
|
|
||||||
|
withoutRT := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.False(t, refresher.NeedsRefresh(withoutRT, 5*time.Minute))
|
||||||
|
|
||||||
|
withRT := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, refresher.NeedsRefresh(withRT, 5*time.Minute))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenProvider_NoRefreshTokenExpiredAccessTokenReturnsError(t *testing.T) {
|
||||||
|
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||||
|
expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "expired-access-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, token)
|
||||||
|
require.Contains(t, err.Error(), "refresh_token is missing")
|
||||||
|
}
|
||||||
|
|||||||
@ -152,6 +152,12 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
// 2) Refresh if needed (pre-expiry skew).
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||||
|
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||||
|
if expiresAt != nil && !time.Now().Before(*expiresAt) {
|
||||||
|
return "", errors.New("openai access_token expired and refresh_token is missing")
|
||||||
|
}
|
||||||
|
needsRefresh = false
|
||||||
|
}
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
|
|
||||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
|
|||||||
@ -424,8 +424,9 @@ func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "fallback-token",
|
"access_token": "fallback-token",
|
||||||
"expires_at": expiresAt,
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -650,8 +651,9 @@ func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "fallback-token",
|
"access_token": "fallback-token",
|
||||||
"expires_at": expiresAt,
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -819,8 +821,9 @@ func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "fallback-token",
|
"access_token": "fallback-token",
|
||||||
"expires_at": expiresAt,
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -848,8 +851,9 @@ func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "fallback-token",
|
"access_token": "fallback-token",
|
||||||
"expires_at": expiresAt,
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -875,8 +879,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T)
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "fallback-token",
|
"access_token": "fallback-token",
|
||||||
"expires_at": expiresAt,
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cacheKey := OpenAITokenCacheKey(account)
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
@ -911,8 +916,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "fallback-token",
|
"access_token": "fallback-token",
|
||||||
"expires_at": expiresAt,
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct {
|
|||||||
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||||
InitialRequestModel string
|
InitialRequestModel string
|
||||||
BeforeTurn func(turn int) error
|
BeforeTurn func(turn int) error
|
||||||
|
BeforeRequest func(turn int, payload []byte, originalModel string) error
|
||||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
|
if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil {
|
||||||
|
if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
|
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
|
||||||
if err := hooks.BeforeTurn(turn); err != nil {
|
if err := hooks.BeforeTurn(turn); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
if msgType != coderws.MessageText {
|
if msgType != coderws.MessageText {
|
||||||
return payload, nil, nil
|
return payload, nil, nil
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil {
|
||||||
|
turnNo := int(completedTurns.Load()) + 1
|
||||||
|
if turnNo < 2 {
|
||||||
|
turnNo = 2
|
||||||
|
}
|
||||||
|
requestModel := usageMeta.requestModelForFrame(payload)
|
||||||
|
if requestModel == "" {
|
||||||
|
requestModel = capturedSessionModel
|
||||||
|
}
|
||||||
|
if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil {
|
||||||
|
return payload, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
||||||
// session.update 修改 session-level model(Realtime /
|
// session.update 修改 session-level model(Realtime /
|
||||||
// Responses WS 协议允许),如果不刷新就会出现
|
// Responses WS 协议允许),如果不刷新就会出现
|
||||||
|
|||||||
@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
|||||||
case redeemActionRedeem:
|
case redeemActionRedeem:
|
||||||
// Code exists but unused — skip creation, proceed to redeem
|
// Code exists but unused — skip creation, proceed to redeem
|
||||||
}
|
}
|
||||||
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
|
if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil {
|
||||||
return fmt.Errorf("redeem balance: %w", err)
|
return fmt.Errorf("redeem balance: %w", err)
|
||||||
}
|
}
|
||||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||||
|
|||||||
@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
client,
|
client,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
registry := payment.NewRegistry()
|
registry := payment.NewRegistry()
|
||||||
provider := &paymentOrderLifecycleQueryProvider{
|
provider := &paymentOrderLifecycleQueryProvider{
|
||||||
@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
client,
|
client,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
registry := payment.NewRegistry()
|
registry := payment.NewRegistry()
|
||||||
provider := &paymentOrderLifecycleQueryProvider{
|
provider := &paymentOrderLifecycleQueryProvider{
|
||||||
@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
client,
|
client,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
registry := payment.NewRegistry()
|
registry := payment.NewRegistry()
|
||||||
provider := &paymentOrderLifecycleQueryProvider{
|
provider := &paymentOrderLifecycleQueryProvider{
|
||||||
@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
|
|||||||
nil,
|
nil,
|
||||||
client,
|
client,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
registry := payment.NewRegistry()
|
registry := payment.NewRegistry()
|
||||||
provider := &paymentOrderLifecycleQueryProvider{
|
provider := &paymentOrderLifecycleQueryProvider{
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +29,15 @@ const (
|
|||||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ctxKeySkipRedeemAffiliate struct{}
|
||||||
|
|
||||||
|
// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level
|
||||||
|
// affiliate rebate. Used by payment fulfillment which handles rebate separately
|
||||||
|
// via applyAffiliateRebateForOrder (with audit-log deduplication).
|
||||||
|
func ContextSkipRedeemAffiliate(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true)
|
||||||
|
}
|
||||||
|
|
||||||
// RedeemCache defines cache operations for redeem service
|
// RedeemCache defines cache operations for redeem service
|
||||||
type RedeemCache interface {
|
type RedeemCache interface {
|
||||||
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||||
@ -80,6 +90,7 @@ type RedeemService struct {
|
|||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
entClient *dbent.Client
|
entClient *dbent.Client
|
||||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
|
affiliateService *AffiliateService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedeemService 创建兑换码服务实例
|
// NewRedeemService 创建兑换码服务实例
|
||||||
@ -91,6 +102,7 @@ func NewRedeemService(
|
|||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
entClient *dbent.Client,
|
entClient *dbent.Client,
|
||||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
|
affiliateService *AffiliateService,
|
||||||
) *RedeemService {
|
) *RedeemService {
|
||||||
return &RedeemService{
|
return &RedeemService{
|
||||||
redeemRepo: redeemRepo,
|
redeemRepo: redeemRepo,
|
||||||
@ -100,6 +112,7 @@ func NewRedeemService(
|
|||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
authCacheInvalidator: authCacheInvalidator,
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
|
affiliateService: affiliateService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
|||||||
// 事务提交成功后失效缓存
|
// 事务提交成功后失效缓存
|
||||||
s.invalidateRedeemCaches(ctx, userID, redeemCode)
|
s.invalidateRedeemCaches(ctx, userID, redeemCode)
|
||||||
|
|
||||||
|
// 余额类正数兑换码触发邀请返利(best-effort,失败不影响兑换结果)
|
||||||
|
if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 {
|
||||||
|
s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value)
|
||||||
|
}
|
||||||
|
|
||||||
// 重新获取更新后的兑换码
|
// 重新获取更新后的兑换码
|
||||||
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) {
|
||||||
|
if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.affiliateService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.affiliateService.IsEnabled(ctx) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rebate > 0 {
|
||||||
|
logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取兑换码
|
// GetByID 根据ID获取兑换码
|
||||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -129,6 +130,8 @@ type AuthSourceDefaultSettings struct {
|
|||||||
LinuxDo ProviderDefaultGrantSettings
|
LinuxDo ProviderDefaultGrantSettings
|
||||||
OIDC ProviderDefaultGrantSettings
|
OIDC ProviderDefaultGrantSettings
|
||||||
WeChat ProviderDefaultGrantSettings
|
WeChat ProviderDefaultGrantSettings
|
||||||
|
GitHub ProviderDefaultGrantSettings
|
||||||
|
Google ProviderDefaultGrantSettings
|
||||||
ForceEmailOnThirdPartySignup bool
|
ForceEmailOnThirdPartySignup bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,6 +172,20 @@ var (
|
|||||||
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||||
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||||
}
|
}
|
||||||
|
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||||
|
balance: SettingKeyAuthSourceDefaultGitHubBalance,
|
||||||
|
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||||
|
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||||
|
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||||
|
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||||
|
}
|
||||||
|
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||||
|
balance: SettingKeyAuthSourceDefaultGoogleBalance,
|
||||||
|
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||||
|
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||||
|
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||||
|
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -177,8 +194,151 @@ const (
|
|||||||
defaultWeChatConnectMode = "open"
|
defaultWeChatConnectMode = "open"
|
||||||
defaultWeChatConnectScopes = "snsapi_login"
|
defaultWeChatConnectScopes = "snsapi_login"
|
||||||
defaultWeChatConnectFrontend = "/auth/wechat/callback"
|
defaultWeChatConnectFrontend = "/auth/wechat/callback"
|
||||||
|
defaultGitHubOAuthAuthorize = "https://github.com/login/oauth/authorize"
|
||||||
|
defaultGitHubOAuthToken = "https://github.com/login/oauth/access_token"
|
||||||
|
defaultGitHubOAuthUserInfo = "https://api.github.com/user"
|
||||||
|
defaultGitHubOAuthEmails = "https://api.github.com/user/emails"
|
||||||
|
defaultGitHubOAuthScopes = "read:user user:email"
|
||||||
|
defaultGitHubOAuthFrontend = "/auth/oauth/callback"
|
||||||
|
defaultGoogleOAuthAuthorize = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
defaultGoogleOAuthToken = "https://oauth2.googleapis.com/token"
|
||||||
|
defaultGoogleOAuthUserInfo = "https://openidconnect.googleapis.com/v1/userinfo"
|
||||||
|
defaultGoogleOAuthScopes = "openid email profile"
|
||||||
|
defaultGoogleOAuthFrontend = "/auth/oauth/callback"
|
||||||
|
defaultLoginAgreementMode = "modal"
|
||||||
|
defaultLoginAgreementDate = "2026-03-31"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func normalizeLoginAgreementMode(raw string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||||
|
case "checkbox":
|
||||||
|
return "checkbox"
|
||||||
|
default:
|
||||||
|
return defaultLoginAgreementMode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultLoginAgreementDocuments() []LoginAgreementDocument {
|
||||||
|
return []LoginAgreementDocument{
|
||||||
|
{
|
||||||
|
ID: "terms",
|
||||||
|
Title: "服务条款",
|
||||||
|
ContentMD: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "usage-policy",
|
||||||
|
Title: "使用政策",
|
||||||
|
ContentMD: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "supported-regions",
|
||||||
|
Title: "支持的国家和地区",
|
||||||
|
ContentMD: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "service-specific-terms",
|
||||||
|
Title: "服务特定条款",
|
||||||
|
ContentMD: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeLoginAgreementDocumentID(raw string) string {
|
||||||
|
raw = strings.ToLower(strings.TrimSpace(raw))
|
||||||
|
var b strings.Builder
|
||||||
|
lastSeparator := false
|
||||||
|
for _, r := range raw {
|
||||||
|
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') {
|
||||||
|
_, _ = b.WriteRune(r)
|
||||||
|
lastSeparator = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if r == '-' || r == '_' || r == ' ' || r == '.' || r == '/' {
|
||||||
|
if !lastSeparator && b.Len() > 0 {
|
||||||
|
if r == '_' {
|
||||||
|
_, _ = b.WriteRune('_')
|
||||||
|
} else {
|
||||||
|
_, _ = b.WriteRune('-')
|
||||||
|
}
|
||||||
|
lastSeparator = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Trim(b.String(), "-_")
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeLoginAgreementDocuments(docs []LoginAgreementDocument) []LoginAgreementDocument {
|
||||||
|
normalized := make([]LoginAgreementDocument, 0, len(docs))
|
||||||
|
seen := make(map[string]int, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
title := strings.TrimSpace(doc.Title)
|
||||||
|
content := strings.TrimSpace(doc.ContentMD)
|
||||||
|
if title == "" && content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id := normalizeLoginAgreementDocumentID(doc.ID)
|
||||||
|
if id == "" {
|
||||||
|
sum := sha256.Sum256([]byte(fmt.Sprintf("%d:%s:%s", i, title, content)))
|
||||||
|
id = hex.EncodeToString(sum[:])[:12]
|
||||||
|
}
|
||||||
|
baseID := id
|
||||||
|
for suffix := 2; seen[id] > 0; suffix++ {
|
||||||
|
id = fmt.Sprintf("%s-%d", baseID, suffix)
|
||||||
|
}
|
||||||
|
seen[id]++
|
||||||
|
normalized = append(normalized, LoginAgreementDocument{
|
||||||
|
ID: id,
|
||||||
|
Title: title,
|
||||||
|
ContentMD: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLoginAgreementDocuments(raw string) []LoginAgreementDocument {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return defaultLoginAgreementDocuments()
|
||||||
|
}
|
||||||
|
var docs []LoginAgreementDocument
|
||||||
|
if err := json.Unmarshal([]byte(raw), &docs); err != nil {
|
||||||
|
return defaultLoginAgreementDocuments()
|
||||||
|
}
|
||||||
|
docs = normalizeLoginAgreementDocuments(docs)
|
||||||
|
if len(docs) == 0 {
|
||||||
|
return defaultLoginAgreementDocuments()
|
||||||
|
}
|
||||||
|
return docs
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalLoginAgreementDocuments(docs []LoginAgreementDocument) (string, error) {
|
||||||
|
normalized := normalizeLoginAgreementDocuments(docs)
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
normalized = defaultLoginAgreementDocuments()
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(normalized)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal login agreement documents: %w", err)
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLoginAgreementRevision(updatedAt string, docs []LoginAgreementDocument) string {
|
||||||
|
normalized := normalizeLoginAgreementDocuments(docs)
|
||||||
|
payload, err := json.Marshal(struct {
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
Documents []LoginAgreementDocument `json:"documents"`
|
||||||
|
}{
|
||||||
|
UpdatedAt: strings.TrimSpace(updatedAt),
|
||||||
|
Documents: normalized,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
payload = []byte(strings.TrimSpace(updatedAt))
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(payload)
|
||||||
|
return hex.EncodeToString(sum[:])[:16]
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeWeChatConnectModeSetting(raw string) string {
|
func normalizeWeChatConnectModeSetting(raw string) string {
|
||||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||||
case "mp":
|
case "mp":
|
||||||
@ -411,6 +571,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
SettingKeyPasswordResetEnabled,
|
SettingKeyPasswordResetEnabled,
|
||||||
SettingKeyInvitationCodeEnabled,
|
SettingKeyInvitationCodeEnabled,
|
||||||
SettingKeyTotpEnabled,
|
SettingKeyTotpEnabled,
|
||||||
|
SettingKeyLoginAgreementEnabled,
|
||||||
|
SettingKeyLoginAgreementMode,
|
||||||
|
SettingKeyLoginAgreementUpdatedAt,
|
||||||
|
SettingKeyLoginAgreementDocuments,
|
||||||
SettingKeyTurnstileEnabled,
|
SettingKeyTurnstileEnabled,
|
||||||
SettingKeyTurnstileSiteKey,
|
SettingKeyTurnstileSiteKey,
|
||||||
SettingKeySiteName,
|
SettingKeySiteName,
|
||||||
@ -448,6 +612,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
SettingPaymentEnabled,
|
SettingPaymentEnabled,
|
||||||
SettingKeyOIDCConnectEnabled,
|
SettingKeyOIDCConnectEnabled,
|
||||||
SettingKeyOIDCConnectProviderName,
|
SettingKeyOIDCConnectProviderName,
|
||||||
|
SettingKeyGitHubOAuthEnabled,
|
||||||
|
SettingKeyGitHubOAuthClientID,
|
||||||
|
SettingKeyGitHubOAuthClientSecret,
|
||||||
|
SettingKeyGoogleOAuthEnabled,
|
||||||
|
SettingKeyGoogleOAuthClientID,
|
||||||
|
SettingKeyGoogleOAuthClientSecret,
|
||||||
SettingKeyBalanceLowNotifyEnabled,
|
SettingKeyBalanceLowNotifyEnabled,
|
||||||
SettingKeyBalanceLowNotifyThreshold,
|
SettingKeyBalanceLowNotifyThreshold,
|
||||||
SettingKeyBalanceLowNotifyRechargeURL,
|
SettingKeyBalanceLowNotifyRechargeURL,
|
||||||
@ -456,6 +626,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
SettingKeyChannelMonitorDefaultIntervalSeconds,
|
SettingKeyChannelMonitorDefaultIntervalSeconds,
|
||||||
SettingKeyAvailableChannelsEnabled,
|
SettingKeyAvailableChannelsEnabled,
|
||||||
SettingKeyAffiliateEnabled,
|
SettingKeyAffiliateEnabled,
|
||||||
|
SettingKeyRiskControlEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||||
@ -482,6 +653,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
if oidcProviderName == "" {
|
if oidcProviderName == "" {
|
||||||
oidcProviderName = "OIDC"
|
oidcProviderName = "OIDC"
|
||||||
}
|
}
|
||||||
|
gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github")
|
||||||
|
googleEnabled := s.emailOAuthPublicEnabled(settings, "google")
|
||||||
weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
|
weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
|
||||||
|
|
||||||
// Password reset requires email verification to be enabled
|
// Password reset requires email verification to be enabled
|
||||||
@ -494,6 +667,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
settings[SettingKeyTableDefaultPageSize],
|
settings[SettingKeyTableDefaultPageSize],
|
||||||
settings[SettingKeyTablePageSizeOptions],
|
settings[SettingKeyTablePageSizeOptions],
|
||||||
)
|
)
|
||||||
|
loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
|
||||||
|
loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
|
||||||
|
if loginAgreementUpdatedAt == "" {
|
||||||
|
loginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||||
|
}
|
||||||
|
|
||||||
var balanceLowNotifyThreshold float64
|
var balanceLowNotifyThreshold float64
|
||||||
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
|
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
|
||||||
@ -509,6 +687,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
PasswordResetEnabled: passwordResetEnabled,
|
PasswordResetEnabled: passwordResetEnabled,
|
||||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||||
|
LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true" && len(loginAgreementDocuments) > 0,
|
||||||
|
LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
|
||||||
|
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||||
|
LoginAgreementRevision: buildLoginAgreementRevision(loginAgreementUpdatedAt, loginAgreementDocuments),
|
||||||
|
LoginAgreementDocuments: loginAgreementDocuments,
|
||||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||||
@ -534,6 +717,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
|
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
|
||||||
OIDCOAuthEnabled: oidcEnabled,
|
OIDCOAuthEnabled: oidcEnabled,
|
||||||
OIDCOAuthProviderName: oidcProviderName,
|
OIDCOAuthProviderName: oidcProviderName,
|
||||||
|
GitHubOAuthEnabled: gitHubEnabled,
|
||||||
|
GoogleOAuthEnabled: googleEnabled,
|
||||||
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
|
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
|
||||||
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
|
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
|
||||||
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
|
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
|
||||||
@ -545,6 +730,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
|
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
|
||||||
|
|
||||||
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
|
AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
|
||||||
|
|
||||||
|
RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -647,43 +834,50 @@ func (s *SettingService) SetVersion(version string) {
|
|||||||
// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
|
// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
|
||||||
// drift automatically (see setting_service_injection_test.go).
|
// drift automatically (see setting_service_injection_test.go).
|
||||||
type PublicSettingsInjectionPayload struct {
|
type PublicSettingsInjectionPayload struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"`
|
TotpEnabled bool `json:"totp_enabled"`
|
||||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
LoginAgreementEnabled bool `json:"login_agreement_enabled"`
|
||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
LoginAgreementMode string `json:"login_agreement_mode"`
|
||||||
SiteName string `json:"site_name"`
|
LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"`
|
||||||
SiteLogo string `json:"site_logo"`
|
LoginAgreementRevision string `json:"login_agreement_revision"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"`
|
||||||
APIBaseURL string `json:"api_base_url"`
|
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||||
ContactInfo string `json:"contact_info"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
DocURL string `json:"doc_url"`
|
SiteName string `json:"site_name"`
|
||||||
HomeContent string `json:"home_content"`
|
SiteLogo string `json:"site_logo"`
|
||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
ContactInfo string `json:"contact_info"`
|
||||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
DocURL string `json:"doc_url"`
|
||||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
HomeContent string `json:"home_content"`
|
||||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||||
PaymentEnabled bool `json:"payment_enabled"`
|
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||||
Version string `json:"version"`
|
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
GitHubOAuthEnabled bool `json:"github_oauth_enabled"`
|
||||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
GoogleOAuthEnabled bool `json:"google_oauth_enabled"`
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
|
PaymentEnabled bool `json:"payment_enabled"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||||
|
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||||
|
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||||
|
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||||
|
|
||||||
// Feature flags — MUST match the opt-in/opt-out registry in
|
// Feature flags — MUST match the opt-in/opt-out registry in
|
||||||
// frontend/src/utils/featureFlags.ts. Missing a field here is the bug
|
// frontend/src/utils/featureFlags.ts. Missing a field here is the bug
|
||||||
@ -692,6 +886,7 @@ type PublicSettingsInjectionPayload struct {
|
|||||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||||
|
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
|
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
|
||||||
@ -710,6 +905,11 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
|||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
|
LoginAgreementEnabled: settings.LoginAgreementEnabled,
|
||||||
|
LoginAgreementMode: settings.LoginAgreementMode,
|
||||||
|
LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt,
|
||||||
|
LoginAgreementRevision: settings.LoginAgreementRevision,
|
||||||
|
LoginAgreementDocuments: settings.LoginAgreementDocuments,
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
SiteName: settings.SiteName,
|
SiteName: settings.SiteName,
|
||||||
@ -733,6 +933,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
|||||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||||
|
GitHubOAuthEnabled: settings.GitHubOAuthEnabled,
|
||||||
|
GoogleOAuthEnabled: settings.GoogleOAuthEnabled,
|
||||||
BackendModeEnabled: settings.BackendModeEnabled,
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
PaymentEnabled: settings.PaymentEnabled,
|
PaymentEnabled: settings.PaymentEnabled,
|
||||||
Version: s.version,
|
Version: s.version,
|
||||||
@ -745,6 +947,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
|||||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||||
AffiliateEnabled: settings.AffiliateEnabled,
|
AffiliateEnabled: settings.AffiliateEnabled,
|
||||||
|
RiskControlEnabled: settings.RiskControlEnabled,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -806,6 +1009,98 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string
|
|||||||
return openReady || mpReady, openReady, mpReady, mobileReady
|
return openReady || mpReady, openReady, mpReady, mobileReady
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||||
|
case "github":
|
||||||
|
cfg := config.EmailOAuthProviderConfig{
|
||||||
|
AuthorizeURL: defaultGitHubOAuthAuthorize,
|
||||||
|
TokenURL: defaultGitHubOAuthToken,
|
||||||
|
UserInfoURL: defaultGitHubOAuthUserInfo,
|
||||||
|
EmailsURL: defaultGitHubOAuthEmails,
|
||||||
|
Scopes: defaultGitHubOAuthScopes,
|
||||||
|
FrontendRedirectURL: defaultGitHubOAuthFrontend,
|
||||||
|
}
|
||||||
|
if s != nil && s.cfg != nil {
|
||||||
|
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GitHubOAuth)
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
case "google":
|
||||||
|
cfg := config.EmailOAuthProviderConfig{
|
||||||
|
AuthorizeURL: defaultGoogleOAuthAuthorize,
|
||||||
|
TokenURL: defaultGoogleOAuthToken,
|
||||||
|
UserInfoURL: defaultGoogleOAuthUserInfo,
|
||||||
|
Scopes: defaultGoogleOAuthScopes,
|
||||||
|
FrontendRedirectURL: defaultGoogleOAuthFrontend,
|
||||||
|
}
|
||||||
|
if s != nil && s.cfg != nil {
|
||||||
|
cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GoogleOAuth)
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
default:
|
||||||
|
return config.EmailOAuthProviderConfig{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) config.EmailOAuthProviderConfig {
|
||||||
|
base.Enabled = override.Enabled
|
||||||
|
if strings.TrimSpace(override.ClientID) != "" {
|
||||||
|
base.ClientID = strings.TrimSpace(override.ClientID)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.ClientSecret) != "" {
|
||||||
|
base.ClientSecret = strings.TrimSpace(override.ClientSecret)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.AuthorizeURL) != "" {
|
||||||
|
base.AuthorizeURL = strings.TrimSpace(override.AuthorizeURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.TokenURL) != "" {
|
||||||
|
base.TokenURL = strings.TrimSpace(override.TokenURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.UserInfoURL) != "" {
|
||||||
|
base.UserInfoURL = strings.TrimSpace(override.UserInfoURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.EmailsURL) != "" {
|
||||||
|
base.EmailsURL = strings.TrimSpace(override.EmailsURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.Scopes) != "" {
|
||||||
|
base.Scopes = strings.TrimSpace(override.Scopes)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.RedirectURL) != "" {
|
||||||
|
base.RedirectURL = strings.TrimSpace(override.RedirectURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(override.FrontendRedirectURL) != "" {
|
||||||
|
base.FrontendRedirectURL = strings.TrimSpace(override.FrontendRedirectURL)
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool {
|
||||||
|
cfg := s.effectiveEmailOAuthConfig(settings, provider)
|
||||||
|
return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig {
|
||||||
|
cfg := s.emailOAuthBaseConfig(provider)
|
||||||
|
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||||
|
case "github":
|
||||||
|
if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok {
|
||||||
|
cfg.Enabled = raw == "true"
|
||||||
|
}
|
||||||
|
cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID)
|
||||||
|
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret)
|
||||||
|
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL)
|
||||||
|
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend)
|
||||||
|
case "google":
|
||||||
|
if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok {
|
||||||
|
cfg.Enabled = raw == "true"
|
||||||
|
}
|
||||||
|
cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID)
|
||||||
|
cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret)
|
||||||
|
cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL)
|
||||||
|
cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend)
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
|
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
|
||||||
// array string, returning only items with visibility != "admin".
|
// array string, returning only items with visibility != "admin".
|
||||||
func filterUserVisibleMenuItems(raw string) json.RawMessage {
|
func filterUserVisibleMenuItems(raw string) json.RawMessage {
|
||||||
@ -1052,6 +1347,16 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
if settings.WeChatConnectFrontendRedirectURL == "" {
|
if settings.WeChatConnectFrontendRedirectURL == "" {
|
||||||
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
|
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
|
||||||
}
|
}
|
||||||
|
settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL)
|
||||||
|
settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL)
|
||||||
|
if settings.GitHubOAuthFrontendRedirectURL == "" {
|
||||||
|
settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend
|
||||||
|
}
|
||||||
|
settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL)
|
||||||
|
settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL)
|
||||||
|
if settings.GoogleOAuthFrontendRedirectURL == "" {
|
||||||
|
settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend
|
||||||
|
}
|
||||||
|
|
||||||
updates := make(map[string]string)
|
updates := make(map[string]string)
|
||||||
|
|
||||||
@ -1068,6 +1373,19 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
updates[SettingKeyFrontendURL] = settings.FrontendURL
|
updates[SettingKeyFrontendURL] = settings.FrontendURL
|
||||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||||
|
settings.LoginAgreementMode = normalizeLoginAgreementMode(settings.LoginAgreementMode)
|
||||||
|
settings.LoginAgreementUpdatedAt = strings.TrimSpace(settings.LoginAgreementUpdatedAt)
|
||||||
|
if settings.LoginAgreementUpdatedAt == "" {
|
||||||
|
settings.LoginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||||
|
}
|
||||||
|
loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(settings.LoginAgreementDocuments)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
updates[SettingKeyLoginAgreementEnabled] = strconv.FormatBool(settings.LoginAgreementEnabled)
|
||||||
|
updates[SettingKeyLoginAgreementMode] = settings.LoginAgreementMode
|
||||||
|
updates[SettingKeyLoginAgreementUpdatedAt] = settings.LoginAgreementUpdatedAt
|
||||||
|
updates[SettingKeyLoginAgreementDocuments] = loginAgreementDocumentsJSON
|
||||||
|
|
||||||
// 邮件服务设置(只有非空才更新密码)
|
// 邮件服务设置(只有非空才更新密码)
|
||||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||||
@ -1121,6 +1439,22 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
|
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GitHub / Google 邮箱快捷登录
|
||||||
|
updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled)
|
||||||
|
updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID)
|
||||||
|
updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL
|
||||||
|
updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL
|
||||||
|
if settings.GitHubOAuthClientSecret != "" {
|
||||||
|
updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret)
|
||||||
|
}
|
||||||
|
updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled)
|
||||||
|
updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID)
|
||||||
|
updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL
|
||||||
|
updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL
|
||||||
|
if settings.GoogleOAuthClientSecret != "" {
|
||||||
|
updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
// WeChat Connect OAuth 登录
|
// WeChat Connect OAuth 登录
|
||||||
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
|
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
|
||||||
updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
|
updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
|
||||||
@ -1232,6 +1566,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
// Affiliate (邀请返利) feature switch
|
// Affiliate (邀请返利) feature switch
|
||||||
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
|
updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
|
||||||
|
|
||||||
|
// 风控中心功能开关
|
||||||
|
updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled)
|
||||||
|
|
||||||
// Claude Code version check
|
// Claude Code version check
|
||||||
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
|
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
|
||||||
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
|
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
|
||||||
@ -1273,17 +1610,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
|
|||||||
settings.LinuxDo.Subscriptions,
|
settings.LinuxDo.Subscriptions,
|
||||||
settings.OIDC.Subscriptions,
|
settings.OIDC.Subscriptions,
|
||||||
settings.WeChat.Subscriptions,
|
settings.WeChat.Subscriptions,
|
||||||
|
settings.GitHub.Subscriptions,
|
||||||
|
settings.Google.Subscriptions,
|
||||||
} {
|
} {
|
||||||
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updates := make(map[string]string, 21)
|
updates := make(map[string]string, 31)
|
||||||
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
||||||
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
||||||
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
||||||
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
||||||
|
writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
|
||||||
|
writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
|
||||||
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
||||||
return updates, nil
|
return updates, nil
|
||||||
}
|
}
|
||||||
@ -1362,6 +1703,61 @@ func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) {
|
||||||
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
if provider != "github" && provider != "google" {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_PROVIDER_NOT_FOUND", "oauth provider not found")
|
||||||
|
}
|
||||||
|
keys := []string{
|
||||||
|
SettingKeyGitHubOAuthEnabled,
|
||||||
|
SettingKeyGitHubOAuthClientID,
|
||||||
|
SettingKeyGitHubOAuthClientSecret,
|
||||||
|
SettingKeyGitHubOAuthRedirectURL,
|
||||||
|
SettingKeyGitHubOAuthFrontendRedirectURL,
|
||||||
|
SettingKeyGoogleOAuthEnabled,
|
||||||
|
SettingKeyGoogleOAuthClientID,
|
||||||
|
SettingKeyGoogleOAuthClientSecret,
|
||||||
|
SettingKeyGoogleOAuthRedirectURL,
|
||||||
|
SettingKeyGoogleOAuthFrontendRedirectURL,
|
||||||
|
}
|
||||||
|
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||||
|
if err != nil {
|
||||||
|
return config.EmailOAuthProviderConfig{}, fmt.Errorf("get email oauth settings: %w", err)
|
||||||
|
}
|
||||||
|
cfg := s.effectiveEmailOAuthConfig(settings, provider)
|
||||||
|
if !cfg.Enabled {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.ClientID) == "" {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||||
|
}
|
||||||
|
for label, rawURL := range map[string]string{
|
||||||
|
"authorize": cfg.AuthorizeURL,
|
||||||
|
"token": cfg.TokenURL,
|
||||||
|
"userinfo": cfg.UserInfoURL,
|
||||||
|
"redirect": cfg.RedirectURL,
|
||||||
|
} {
|
||||||
|
if strings.TrimSpace(rawURL) == "" {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url not configured")
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(rawURL); err != nil {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.EmailsURL) != "" {
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(cfg.EmailsURL); err != nil {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth emails url invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
|
||||||
|
return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
// IsRegistrationEnabled 检查是否开放注册
|
// IsRegistrationEnabled 检查是否开放注册
|
||||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||||
@ -1534,6 +1930,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
|
|||||||
return value == "true"
|
return value == "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting.
|
||||||
|
func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string {
|
||||||
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems)
|
||||||
|
if err != nil {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
|
// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
|
||||||
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
|
func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
|
||||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
|
||||||
@ -1711,6 +2116,16 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
|||||||
SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
||||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||||
|
SettingKeyAuthSourceDefaultGitHubBalance,
|
||||||
|
SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||||
|
SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||||
|
SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||||
|
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||||
|
SettingKeyAuthSourceDefaultGoogleBalance,
|
||||||
|
SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||||
|
SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||||
|
SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||||
|
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||||
SettingKeyForceEmailOnThirdPartySignup,
|
SettingKeyForceEmailOnThirdPartySignup,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1724,6 +2139,8 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
|||||||
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
|
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
|
||||||
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
|
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
|
||||||
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
|
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
|
||||||
|
GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys),
|
||||||
|
Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys),
|
||||||
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
|
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -1793,6 +2210,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
|
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(defaultLoginAgreementDocuments())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// 初始化默认设置
|
// 初始化默认设置
|
||||||
defaults := map[string]string{
|
defaults := map[string]string{
|
||||||
@ -1800,6 +2221,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeyEmailVerifyEnabled: "false",
|
SettingKeyEmailVerifyEnabled: "false",
|
||||||
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||||
|
SettingKeyLoginAgreementEnabled: "false",
|
||||||
|
SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
|
||||||
|
SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
|
||||||
|
SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
|
||||||
SettingKeySiteName: "Sub2API",
|
SettingKeySiteName: "Sub2API",
|
||||||
SettingKeySiteLogo: "",
|
SettingKeySiteLogo: "",
|
||||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||||
@ -1824,6 +2249,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeyWeChatConnectScopes: "snsapi_login",
|
SettingKeyWeChatConnectScopes: "snsapi_login",
|
||||||
SettingKeyWeChatConnectRedirectURL: "",
|
SettingKeyWeChatConnectRedirectURL: "",
|
||||||
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
|
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
|
||||||
|
SettingKeyGitHubOAuthEnabled: "false",
|
||||||
|
SettingKeyGitHubOAuthClientID: "",
|
||||||
|
SettingKeyGitHubOAuthClientSecret: "",
|
||||||
|
SettingKeyGitHubOAuthRedirectURL: "",
|
||||||
|
SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
|
||||||
|
SettingKeyGoogleOAuthEnabled: "false",
|
||||||
|
SettingKeyGoogleOAuthClientID: "",
|
||||||
|
SettingKeyGoogleOAuthClientSecret: "",
|
||||||
|
SettingKeyGoogleOAuthRedirectURL: "",
|
||||||
|
SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
|
||||||
SettingKeyOIDCConnectEnabled: "false",
|
SettingKeyOIDCConnectEnabled: "false",
|
||||||
SettingKeyOIDCConnectProviderName: "OIDC",
|
SettingKeyOIDCConnectProviderName: "OIDC",
|
||||||
SettingKeyOIDCConnectClientID: "",
|
SettingKeyOIDCConnectClientID: "",
|
||||||
@ -1874,6 +2309,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
|
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
|
||||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
|
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
|
||||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
|
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
|
||||||
|
SettingKeyAuthSourceDefaultGitHubBalance: "0",
|
||||||
|
SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
|
||||||
|
SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
|
||||||
|
SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
|
||||||
|
SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
|
||||||
|
SettingKeyAuthSourceDefaultGoogleBalance: "0",
|
||||||
|
SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
|
||||||
|
SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
|
||||||
|
SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
|
||||||
|
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
|
||||||
SettingKeyForceEmailOnThirdPartySignup: "false",
|
SettingKeyForceEmailOnThirdPartySignup: "false",
|
||||||
SettingKeySMTPPort: "587",
|
SettingKeySMTPPort: "587",
|
||||||
SettingKeySMTPUseTLS: "false",
|
SettingKeySMTPUseTLS: "false",
|
||||||
@ -1903,6 +2348,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
// Affiliate (邀请返利) feature (default disabled; opt-in)
|
// Affiliate (邀请返利) feature (default disabled; opt-in)
|
||||||
SettingKeyAffiliateEnabled: "false",
|
SettingKeyAffiliateEnabled: "false",
|
||||||
|
|
||||||
|
// 风控中心功能(默认关闭,显式启用)
|
||||||
|
SettingKeyRiskControlEnabled: "false",
|
||||||
|
|
||||||
// Claude Code version check (default: empty = disabled)
|
// Claude Code version check (default: empty = disabled)
|
||||||
SettingKeyMinClaudeCodeVersion: "",
|
SettingKeyMinClaudeCodeVersion: "",
|
||||||
SettingKeyMaxClaudeCodeVersion: "",
|
SettingKeyMaxClaudeCodeVersion: "",
|
||||||
@ -1923,6 +2371,11 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
// parseSettings 解析设置到结构体
|
// parseSettings 解析设置到结构体
|
||||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||||
|
loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
|
||||||
|
loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
|
||||||
|
if loginAgreementUpdatedAt == "" {
|
||||||
|
loginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||||
|
}
|
||||||
result := &SystemSettings{
|
result := &SystemSettings{
|
||||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||||
EmailVerifyEnabled: emailVerifyEnabled,
|
EmailVerifyEnabled: emailVerifyEnabled,
|
||||||
@ -1932,6 +2385,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
FrontendURL: settings[SettingKeyFrontendURL],
|
FrontendURL: settings[SettingKeyFrontendURL],
|
||||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||||
|
LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true",
|
||||||
|
LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
|
||||||
|
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||||
|
LoginAgreementDocuments: loginAgreementDocuments,
|
||||||
SMTPHost: settings[SettingKeySMTPHost],
|
SMTPHost: settings[SettingKeySMTPHost],
|
||||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||||
@ -2173,6 +2630,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
}
|
}
|
||||||
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
|
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
|
||||||
|
|
||||||
|
gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github")
|
||||||
|
result.GitHubOAuthEnabled = gitHubEffective.Enabled
|
||||||
|
result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID)
|
||||||
|
result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret)
|
||||||
|
result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != ""
|
||||||
|
result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL)
|
||||||
|
result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL)
|
||||||
|
|
||||||
|
googleEffective := s.effectiveEmailOAuthConfig(settings, "google")
|
||||||
|
result.GoogleOAuthEnabled = googleEffective.Enabled
|
||||||
|
result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID)
|
||||||
|
result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret)
|
||||||
|
result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != ""
|
||||||
|
result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL)
|
||||||
|
result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL)
|
||||||
|
|
||||||
// WeChat Connect 设置:
|
// WeChat Connect 设置:
|
||||||
// - 优先读取 DB 系统设置
|
// - 优先读取 DB 系统设置
|
||||||
// - 缺失时回退到 config/env,保持升级兼容
|
// - 缺失时回退到 config/env,保持升级兼容
|
||||||
@ -2242,6 +2715,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
// Affiliate (邀请返利) feature (default: disabled; strict true)
|
// Affiliate (邀请返利) feature (default: disabled; strict true)
|
||||||
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
|
result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
|
||||||
|
|
||||||
|
// 风控中心功能(默认关闭,严格 true 才启用)
|
||||||
|
result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true"
|
||||||
|
|
||||||
// Claude Code version check
|
// Claude Code version check
|
||||||
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
|
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
|
||||||
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
|
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
|
||||||
|
|||||||
@ -20,6 +20,10 @@ type SystemSettings struct {
|
|||||||
FrontendURL string
|
FrontendURL string
|
||||||
InvitationCodeEnabled bool
|
InvitationCodeEnabled bool
|
||||||
TotpEnabled bool // TOTP 双因素认证
|
TotpEnabled bool // TOTP 双因素认证
|
||||||
|
LoginAgreementEnabled bool
|
||||||
|
LoginAgreementMode string
|
||||||
|
LoginAgreementUpdatedAt string
|
||||||
|
LoginAgreementDocuments []LoginAgreementDocument
|
||||||
|
|
||||||
SMTPHost string
|
SMTPHost string
|
||||||
SMTPPort int
|
SMTPPort int
|
||||||
@ -89,6 +93,20 @@ type SystemSettings struct {
|
|||||||
OIDCConnectUserInfoIDPath string
|
OIDCConnectUserInfoIDPath string
|
||||||
OIDCConnectUserInfoUsernamePath string
|
OIDCConnectUserInfoUsernamePath string
|
||||||
|
|
||||||
|
// GitHub / Google 邮箱快捷登录
|
||||||
|
GitHubOAuthEnabled bool
|
||||||
|
GitHubOAuthClientID string
|
||||||
|
GitHubOAuthClientSecret string
|
||||||
|
GitHubOAuthClientSecretConfigured bool
|
||||||
|
GitHubOAuthRedirectURL string
|
||||||
|
GitHubOAuthFrontendRedirectURL string
|
||||||
|
GoogleOAuthEnabled bool
|
||||||
|
GoogleOAuthClientID string
|
||||||
|
GoogleOAuthClientSecret string
|
||||||
|
GoogleOAuthClientSecretConfigured bool
|
||||||
|
GoogleOAuthRedirectURL string
|
||||||
|
GoogleOAuthFrontendRedirectURL string
|
||||||
|
|
||||||
SiteName string
|
SiteName string
|
||||||
SiteLogo string
|
SiteLogo string
|
||||||
SiteSubtitle string
|
SiteSubtitle string
|
||||||
@ -106,6 +124,7 @@ type SystemSettings struct {
|
|||||||
|
|
||||||
DefaultConcurrency int
|
DefaultConcurrency int
|
||||||
DefaultBalance float64
|
DefaultBalance float64
|
||||||
|
RiskControlEnabled bool
|
||||||
AffiliateEnabled bool
|
AffiliateEnabled bool
|
||||||
AffiliateRebateRate float64
|
AffiliateRebateRate float64
|
||||||
AffiliateRebateFreezeHours int
|
AffiliateRebateFreezeHours int
|
||||||
@ -190,6 +209,11 @@ type PublicSettings struct {
|
|||||||
PasswordResetEnabled bool
|
PasswordResetEnabled bool
|
||||||
InvitationCodeEnabled bool
|
InvitationCodeEnabled bool
|
||||||
TotpEnabled bool // TOTP 双因素认证
|
TotpEnabled bool // TOTP 双因素认证
|
||||||
|
LoginAgreementEnabled bool
|
||||||
|
LoginAgreementMode string
|
||||||
|
LoginAgreementUpdatedAt string
|
||||||
|
LoginAgreementRevision string
|
||||||
|
LoginAgreementDocuments []LoginAgreementDocument
|
||||||
TurnstileEnabled bool
|
TurnstileEnabled bool
|
||||||
TurnstileSiteKey string
|
TurnstileSiteKey string
|
||||||
SiteName string
|
SiteName string
|
||||||
@ -217,6 +241,8 @@ type PublicSettings struct {
|
|||||||
PaymentEnabled bool
|
PaymentEnabled bool
|
||||||
OIDCOAuthEnabled bool
|
OIDCOAuthEnabled bool
|
||||||
OIDCOAuthProviderName string
|
OIDCOAuthProviderName string
|
||||||
|
GitHubOAuthEnabled bool
|
||||||
|
GoogleOAuthEnabled bool
|
||||||
Version string
|
Version string
|
||||||
|
|
||||||
BalanceLowNotifyEnabled bool
|
BalanceLowNotifyEnabled bool
|
||||||
@ -233,6 +259,15 @@ type PublicSettings struct {
|
|||||||
|
|
||||||
// Affiliate (邀请返利) feature toggle
|
// Affiliate (邀请返利) feature toggle
|
||||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||||
|
|
||||||
|
// 风控中心功能开关
|
||||||
|
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginAgreementDocument struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
ContentMD string `json:"content_md"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WeChatConnectOAuthConfig struct {
|
type WeChatConnectOAuthConfig struct {
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -95,6 +96,9 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
|||||||
// NeedsRefresh 检查token是否需要刷新
|
// NeedsRefresh 检查token是否需要刷新
|
||||||
// expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期
|
// expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期
|
||||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||||
|
if strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
if expiresAt == nil {
|
if expiresAt == nil {
|
||||||
return account.IsRateLimited()
|
return account.IsRateLimited()
|
||||||
|
|||||||
@ -97,6 +97,8 @@ type UserRepository interface {
|
|||||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||||
|
BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error)
|
||||||
|
BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error)
|
||||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||||
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
|
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
|
||||||
|
|||||||
@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
|
|||||||
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
|
func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||||
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||||
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||||
out := make([]UserAuthIdentityRecord, len(m.identities))
|
out := make([]UserAuthIdentityRecord, len(m.identities))
|
||||||
|
|||||||
@ -515,6 +515,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewGroupCapacityService,
|
NewGroupCapacityService,
|
||||||
NewChannelService,
|
NewChannelService,
|
||||||
NewModelPricingResolver,
|
NewModelPricingResolver,
|
||||||
|
NewContentModerationService,
|
||||||
NewAffiliateService,
|
NewAffiliateService,
|
||||||
ProvidePaymentConfigService,
|
ProvidePaymentConfigService,
|
||||||
NewPaymentService,
|
NewPaymentService,
|
||||||
|
|||||||
27
backend/migrations/135_allow_email_oauth_provider_types.sql
Normal file
27
backend/migrations/135_allow_email_oauth_provider_types.sql
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
ALTER TABLE users
|
||||||
|
DROP CONSTRAINT IF EXISTS users_signup_source_check;
|
||||||
|
|
||||||
|
ALTER TABLE users
|
||||||
|
ADD CONSTRAINT users_signup_source_check
|
||||||
|
CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||||
|
|
||||||
|
ALTER TABLE auth_identities
|
||||||
|
DROP CONSTRAINT IF EXISTS auth_identities_provider_type_check;
|
||||||
|
|
||||||
|
ALTER TABLE auth_identities
|
||||||
|
ADD CONSTRAINT auth_identities_provider_type_check
|
||||||
|
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||||
|
|
||||||
|
ALTER TABLE auth_identity_channels
|
||||||
|
DROP CONSTRAINT IF EXISTS auth_identity_channels_provider_type_check;
|
||||||
|
|
||||||
|
ALTER TABLE auth_identity_channels
|
||||||
|
ADD CONSTRAINT auth_identity_channels_provider_type_check
|
||||||
|
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||||
|
|
||||||
|
ALTER TABLE pending_auth_sessions
|
||||||
|
DROP CONSTRAINT IF EXISTS pending_auth_sessions_provider_type_check;
|
||||||
|
|
||||||
|
ALTER TABLE pending_auth_sessions
|
||||||
|
ADD CONSTRAINT pending_auth_sessions_provider_type_check
|
||||||
|
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google'));
|
||||||
45
backend/migrations/135_content_moderation.sql
Normal file
45
backend/migrations/135_content_moderation.sql
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
-- 风控中心内容审计配置与记录
|
||||||
|
|
||||||
|
INSERT INTO settings (key, value, updated_at)
|
||||||
|
VALUES ('risk_control_enabled', 'false', NOW())
|
||||||
|
ON CONFLICT (key) DO NOTHING;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS content_moderation_logs (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
request_id VARCHAR(128) NOT NULL DEFAULT '',
|
||||||
|
user_id BIGINT REFERENCES users(id) ON DELETE SET NULL,
|
||||||
|
user_email VARCHAR(255) NOT NULL DEFAULT '',
|
||||||
|
api_key_id BIGINT REFERENCES api_keys(id) ON DELETE SET NULL,
|
||||||
|
api_key_name VARCHAR(100) NOT NULL DEFAULT '',
|
||||||
|
group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL,
|
||||||
|
group_name VARCHAR(255) NOT NULL DEFAULT '',
|
||||||
|
endpoint VARCHAR(128) NOT NULL DEFAULT '',
|
||||||
|
provider VARCHAR(64) NOT NULL DEFAULT '',
|
||||||
|
model VARCHAR(255) NOT NULL DEFAULT '',
|
||||||
|
mode VARCHAR(32) NOT NULL DEFAULT '',
|
||||||
|
action VARCHAR(32) NOT NULL DEFAULT '',
|
||||||
|
flagged BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
highest_category VARCHAR(64) NOT NULL DEFAULT '',
|
||||||
|
highest_score DECIMAL(8, 6) NOT NULL DEFAULT 0,
|
||||||
|
category_scores JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
threshold_snapshot JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
input_excerpt TEXT NOT NULL DEFAULT '',
|
||||||
|
upstream_latency_ms INT,
|
||||||
|
error TEXT NOT NULL DEFAULT '',
|
||||||
|
violation_count INT NOT NULL DEFAULT 0,
|
||||||
|
auto_banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
email_sent BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
queue_delay_ms INT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS violation_count INT NOT NULL DEFAULT 0;
|
||||||
|
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS auto_banned BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
|
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS email_sent BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
|
ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS queue_delay_ms INT;
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_created_at ON content_moderation_logs(created_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_group_created_at ON content_moderation_logs(group_id, created_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_flagged_created_at ON content_moderation_logs(flagged, created_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_user_created_at ON content_moderation_logs(user_id, created_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_api_key_created_at ON content_moderation_logs(api_key_id, created_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_endpoint_created_at ON content_moderation_logs(endpoint, created_at DESC);
|
||||||
@ -142,3 +142,16 @@ func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T)
|
|||||||
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
|
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
|
||||||
require.NotContains(t, sql, "detail::jsonb")
|
require.NotContains(t, sql, "detail::jsonb")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigration135AllowsGitHubAndGoogleAuthProviders(t *testing.T) {
|
||||||
|
content, err := FS.ReadFile("135_allow_email_oauth_provider_types.sql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sql := string(content)
|
||||||
|
require.Contains(t, sql, "users_signup_source_check")
|
||||||
|
require.Contains(t, sql, "auth_identities_provider_type_check")
|
||||||
|
require.Contains(t, sql, "auth_identity_channels_provider_type_check")
|
||||||
|
require.Contains(t, sql, "pending_auth_sessions_provider_type_check")
|
||||||
|
require.Contains(t, sql, "'github'")
|
||||||
|
require.Contains(t, sql, "'google'")
|
||||||
|
}
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.26.2-alpine
|
ARG GOLANG_IMAGE=golang:1.26.3-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.20
|
ARG ALPINE_IMAGE=alpine:3.20
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
ARG GOSUMDB=sum.golang.google.cn
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|||||||
@ -202,6 +202,14 @@ gateway:
|
|||||||
#
|
#
|
||||||
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
|
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
|
||||||
force_codex_cli: false
|
force_codex_cli: false
|
||||||
|
# Enable Codex image-generation bridge injection for /openai/v1/responses.
|
||||||
|
# 是否为 Codex /responses 请求自动注入 image_generation 工具与桥接指令。
|
||||||
|
#
|
||||||
|
# Default false keeps text-only Codex requests text-only. Explicit client-provided
|
||||||
|
# image_generation tools are still forwarded when the group allows image generation.
|
||||||
|
# 默认 false:保持纯文本 Codex 请求不被改写;客户端显式提供 image_generation tool 时,
|
||||||
|
# 仍会在分组允许图片生成的情况下正常转发。
|
||||||
|
codex_image_generation_bridge_enabled: false
|
||||||
# Optional: template file used to build the final top-level Codex `instructions`.
|
# Optional: template file used to build the final top-level Codex `instructions`.
|
||||||
# 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。
|
# 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。
|
||||||
#
|
#
|
||||||
|
|||||||
@ -16,6 +16,8 @@ import type {
|
|||||||
TempUnschedulableStatus,
|
TempUnschedulableStatus,
|
||||||
AdminDataPayload,
|
AdminDataPayload,
|
||||||
AdminDataImportResult,
|
AdminDataImportResult,
|
||||||
|
CodexSessionImportRequest,
|
||||||
|
CodexSessionImportResult,
|
||||||
CheckMixedChannelRequest,
|
CheckMixedChannelRequest,
|
||||||
CheckMixedChannelResponse
|
CheckMixedChannelResponse
|
||||||
} from '@/types'
|
} from '@/types'
|
||||||
@ -547,6 +549,11 @@ export async function importData(payload: {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function importCodexSession(payload: CodexSessionImportRequest): Promise<CodexSessionImportResult> {
|
||||||
|
const { data } = await apiClient.post<CodexSessionImportResult>('/admin/accounts/import/codex-session', payload)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get Antigravity default model mapping from backend
|
* Get Antigravity default model mapping from backend
|
||||||
* @returns Default model mapping (from -> to)
|
* @returns Default model mapping (from -> to)
|
||||||
@ -663,6 +670,7 @@ export const accountsAPI = {
|
|||||||
syncFromCrs,
|
syncFromCrs,
|
||||||
exportData,
|
exportData,
|
||||||
importData,
|
importData,
|
||||||
|
importCodexSession,
|
||||||
getAntigravityDefaultModelMapping,
|
getAntigravityDefaultModelMapping,
|
||||||
batchClearError,
|
batchClearError,
|
||||||
batchRefresh,
|
batchRefresh,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user