diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 95e0c9b1..c4fe8805 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -20,7 +20,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.26.2' + go version | grep -q 'go1.26.3' - name: Unit tests working-directory: backend run: make test-unit @@ -60,7 +60,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.26.2' + go version | grep -q 'go1.26.3' - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 26ed8524..80bc9850 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,7 +115,7 @@ jobs: - name: Verify Go version run: | - go version | grep -q 'go1.26.2' + go version | grep -q 'go1.26.3' # Docker setup for GoReleaser - name: Set up QEMU diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index 600fd2fa..ef8e59e5 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -23,7 +23,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.26.2' + go version | grep -q 'go1.26.3' - name: Run govulncheck working-directory: backend run: | diff --git a/Dockerfile b/Dockerfile index 26937df1..7befb464 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.26-alpine +ARG GOLANG_IMAGE=golang:1.26.3-alpine ARG ALPINE_IMAGE=alpine:3.21 ARG POSTGRES_IMAGE=postgres:18-alpine ARG GOPROXY=https://goproxy.cn,direct diff --git a/README.md b/README.md index 718730c6..bdb09d15 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot -silkapi -Thanks to SilkAPI for sponsoring this project! SilkAPI is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay. +silkapi +Thanks to SilkAPI for sponsoring this project! SilkAPI is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay. diff --git a/README_CN.md b/README_CN.md index 24600e0e..e13f86de 100644 --- a/README_CN.md +++ b/README_CN.md @@ -71,8 +71,8 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 -silkapi -感谢 丝绸API 赞助了本项目! 丝绸API 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。 +silkapi +感谢 丝绸API 赞助了本项目! 丝绸API 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。 diff --git a/README_JA.md b/README_JA.md index 1e89610c..73331a07 100644 --- a/README_JA.md +++ b/README_JA.md @@ -71,8 +71,8 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを -silkapi -SilkAPI のご支援に感謝します!SilkAPI は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。 +silkapi +SilkAPI のご支援に感謝します!SilkAPI は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。 diff --git a/backend/Dockerfile b/backend/Dockerfile index b6823a80..f153d686 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.26-alpine +FROM golang:1.26.3-alpine WORKDIR /app diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 2c88ab23..5076ee80 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.123 +0.1.125 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 719dc3b7..239f2be9 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -8,11 +8,6 @@ package main import ( "context" - "log" - "net/http" - "sync" - "time" - "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -23,9 +18,14 @@ import ( "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" + "log" + "net/http" + "sync" + "time" +) +import ( _ "embed" - _ "github.com/Wei-Shaw/sub2api/ent/runtime" ) @@ -74,7 +74,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService) secretEncryptor, err := repository.NewAESEncryptor(configConfig) if err != nil { return nil, err @@ -136,26 +136,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) + oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) + antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) + antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider) usageCache := service.NewUsageCache() identityCache := repository.NewIdentityCache(redisClient) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) - oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) + accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) - antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) - antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository, antigravityTokenProvider) - accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) windsurfLSService := service.ProvideWindsurfLSService(configConfig) windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository) windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache) - windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) @@ -237,19 +236,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db) channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) + contentModerationRepository := repository.NewContentModerationRepository(db) + contentModerationHashCache := repository.NewContentModerationHashCache(redisClient) + contentModerationService := service.NewContentModerationService(settingRepository, contentModerationRepository, contentModerationHashCache, groupRepository, userRepository, apiKeyAuthCacheInvalidator, emailService) + contentModerationHandler := admin.NewContentModerationHandler(contentModerationService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService) - windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository) windsurfTierAccessService := service.ProvideWindsurfTierAccessService(configConfig, accountRepository) windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService, windsurfTierAccessService) affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler, affiliateHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, windsurfHandler, affiliateHandler) + windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) @@ -274,6 +277,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) + windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService) application := &Application{ @@ -485,6 +489,12 @@ func provideCleanup( } return nil }}, + {"WindsurfLSService", func() error { + if windsurfLS != nil { + windsurfLS.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go index 0b1b56ab..5f864080 100644 --- a/backend/ent/schema/auth_identity.go +++ b/backend/ent/schema/auth_identity.go @@ -16,6 +16,8 @@ import ( var authProviderTypes = map[string]struct{}{ "email": {}, + "github": {}, + "google": {}, "linuxdo": {}, "oidc": {}, "wechat": {}, diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go index fbb93236..d3e24050 100644 --- a/backend/ent/schema/auth_identity_schema_test.go +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -83,10 +83,10 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) { require.Equal(t, 1, signupSource.Validators) validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source") - for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} { + for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} { require.NoError(t, validator(value)) } - require.Error(t, validator("github")) + require.Error(t, validator("unknown")) } func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index 83da5c32..08bab83a 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -77,10 +77,10 @@ func (User) Fields() []ent.Field { field.String("signup_source"). Validate(func(value string) error { switch value { - case "email", "linuxdo", "wechat", "oidc": + case "email", "linuxdo", "wechat", "oidc", "github", "google": return nil default: - return fmt.Errorf("must be one of email, linuxdo, wechat, oidc") + return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google") } }). Default("email"), diff --git a/backend/go.mod b/backend/go.mod index 1a7e2dd1..74fb4a6e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,6 +1,6 @@ module github.com/Wei-Shaw/sub2api -go 1.26.2 +go 1.26.3 require ( connectrpc.com/connect v1.19.2 @@ -21,6 +21,7 @@ require ( github.com/google/wire v0.7.0 github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 + github.com/klauspost/compress v1.18.2 github.com/lib/pq v1.10.9 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pquerna/otp v1.5.0 @@ -40,11 +41,11 @@ require ( github.com/wechatpay-apiv3/wechatpay-go v0.2.21 github.com/zeromicro/go-zero v1.9.4 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.49.0 + golang.org/x/crypto v0.50.0 golang.org/x/image v0.39.0 - golang.org/x/net v0.52.0 + golang.org/x/net v0.53.0 golang.org/x/sync v0.20.0 - golang.org/x/term v0.41.0 + golang.org/x/term v0.42.0 google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 @@ -112,7 +113,6 @@ require ( github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/icholy/digest v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect @@ -176,7 +176,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/mod v0.34.0 // indirect - golang.org/x/sys v0.42.0 // indirect + golang.org/x/sys v0.43.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/tools v0.43.0 // indirect google.golang.org/grpc v1.75.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index b2699f2b..130d8eb4 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -187,8 +187,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -224,8 +222,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -259,8 +255,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -290,8 +284,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -324,8 +316,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -417,16 +407,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= -golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww= golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -438,10 +428,10 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= -golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index fbbab69a..b0c105c6 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -79,6 +79,8 @@ type Config struct { LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` WeChat WeChatConnectConfig `mapstructure:"wechat_connect"` OIDC OIDCConnectConfig `mapstructure:"oidc_connect"` + GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"` + GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"` Default DefaultConfig `mapstructure:"default"` RateLimit RateLimitConfig `mapstructure:"rate_limit"` Pricing PricingConfig `mapstructure:"pricing"` @@ -248,6 +250,19 @@ type OIDCConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +type EmailOAuthProviderConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + EmailsURL string `mapstructure:"emails_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` +} + const ( defaultWeChatConnectMode = "open" defaultWeChatConnectScopes = "snsapi_login" @@ -619,6 +634,9 @@ type GatewayConfig struct { // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 ForceCodexCLI bool `mapstructure:"force_codex_cli"` + // CodexImageGenerationBridgeEnabled: 是否为 Codex `/v1/responses` 自动注入 image_generation 工具和桥接指令。 + // 默认关闭,避免纯文本 Codex 请求被意外改写;显式携带 image_generation 工具的请求仍按分组能力转发。 + CodexImageGenerationBridgeEnabled bool `mapstructure:"codex_image_generation_bridge_enabled"` // ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。 // 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。 ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"` @@ -1773,6 +1791,7 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.force_codex_cli", false) + viper.SetDefault("gateway.codex_image_generation_bridge_enabled", false) viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) viper.SetDefault("gateway.openai_ws.enabled", true) diff --git a/backend/internal/handler/admin/account_codex_import.go b/backend/internal/handler/admin/account_codex_import.go new file mode 100644 index 00000000..0c599522 --- /dev/null +++ b/backend/internal/handler/admin/account_codex_import.go @@ -0,0 +1,1045 @@ +package admin + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const codexImportClockSkewSeconds int64 = 120 + +type CodexSessionImportRequest struct { + Content string `json:"content"` + Contents []string `json:"contents"` + Name string `json:"name"` + Notes *string `json:"notes"` + GroupIDs []int64 `json:"group_ids"` + ProxyID *int64 `json:"proxy_id"` + Concurrency *int `json:"concurrency"` + Priority *int `json:"priority"` + RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + CredentialExtras map[string]any `json:"credential_extras"` + Extra map[string]any `json:"extra"` + UpdateExisting *bool `json:"update_existing"` + SkipDefaultGroupBind *bool `json:"skip_default_group_bind"` + ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` +} + +type CodexSessionImportResult struct { + Total int `json:"total"` + Created int `json:"created"` + Updated int `json:"updated"` + Skipped int `json:"skipped"` + Failed int `json:"failed"` + Items []CodexSessionImportItem `json:"items,omitempty"` + Warnings []CodexSessionImportMessage `json:"warnings,omitempty"` + Errors []CodexSessionImportMessage `json:"errors,omitempty"` +} + +type CodexSessionImportItem struct { + Index int `json:"index"` + Name string `json:"name,omitempty"` + Action string `json:"action"` + AccountID int64 `json:"account_id,omitempty"` + Message string `json:"message,omitempty"` +} + +type CodexSessionImportMessage struct { + Index int `json:"index"` + Name string `json:"name,omitempty"` + Message string `json:"message"` +} + +type codexImportEntry struct { + Index int + Value any +} + +type codexImportAccount struct { + Name string + AccessToken string + RefreshToken string + IDToken string + Email string + AccountID string + UserID string + PlanType string + Organization string + Credentials map[string]any + Extra map[string]any + TokenExpiresAt *time.Time + IdentityKeys []string + WarningTexts []string +} + +type codexJWTClaims struct { + Sub string `json:"sub"` + Email string `json:"email"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + OpenAIAuth *codexJWTOpenAIClaims `json:"https://api.openai.com/auth,omitempty"` +} + +type codexJWTOpenAIClaims struct { + ChatGPTAccountID string `json:"chatgpt_account_id"` + ChatGPTUserID string `json:"chatgpt_user_id"` + ChatGPTPlanType string `json:"chatgpt_plan_type"` + UserID string `json:"user_id"` + POID string `json:"poid"` + Organizations []openai.OrganizationClaim `json:"organizations"` +} + +type codexAccountIndex struct { + accountsByKey map[string]service.Account +} + +func (h *AccountHandler) ImportCodexSession(c *gin.Context) { + var req CodexSessionImportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if req.Concurrency != nil && *req.Concurrency < 0 { + response.BadRequest(c, "concurrency must be >= 0") + return + } + if req.Priority != nil && *req.Priority < 0 { + response.BadRequest(c, "priority must be >= 0") + return + } + if req.RateMultiplier != nil && *req.RateMultiplier < 0 { + response.BadRequest(c, "rate_multiplier must be >= 0") + return + } + if req.LoadFactor != nil && *req.LoadFactor > 10000 { + response.BadRequest(c, "load_factor must be <= 10000") + return + } + + entries, err := parseCodexSessionImportEntries(req) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + if len(entries) == 0 { + response.BadRequest(c, "请输入 accessToken 或 Codex session JSON") + return + } + + executeAdminIdempotentJSON(c, "admin.accounts.import_codex_session", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + return h.importCodexSessions(ctx, req, entries) + }) +} + +func (h *AccountHandler) importCodexSessions(ctx context.Context, req CodexSessionImportRequest, entries []codexImportEntry) (CodexSessionImportResult, error) { + result := CodexSessionImportResult{ + Total: len(entries), + Items: make([]CodexSessionImportItem, 0, len(entries)), + } + + existingAccounts, err := h.listAccountsFiltered(ctx, service.PlatformOpenAI, service.AccountTypeOAuth, "", "", 0, "", "created_at", "desc") + if err != nil { + return result, err + } + index := buildCodexAccountIndex(existingAccounts) + + updateExisting := true + if req.UpdateExisting != nil { + updateExisting = *req.UpdateExisting + } + concurrency := 3 + if req.Concurrency != nil { + concurrency = *req.Concurrency + } + priority := 50 + if req.Priority != nil { + priority = *req.Priority + } + credentialExtras := sanitizeCodexImportCredentialExtras(req.CredentialExtras) + skipDefaultGroupBind := false + if req.SkipDefaultGroupBind != nil { + skipDefaultGroupBind = *req.SkipDefaultGroupBind + } + skipMixedChannelCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk + + seenIdentity := map[string]int{} + for _, entry := range entries { + item, err := normalizeCodexImportEntry(entry) + if err != nil { + result.Failed++ + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Action: "failed", + Message: err.Error(), + }) + result.Errors = append(result.Errors, CodexSessionImportMessage{ + Index: entry.Index, + Message: err.Error(), + }) + continue + } + accountName := buildCodexCreateAccountName(req.Name, item, entry.Index, len(entries)) + effectiveExpiresAt, credentialExpiresAt, autoPauseOnExpired, expiryWarnings, expiryErr := resolveCodexImportExpiry(req, item) + if expiryErr != nil { + result.Failed++ + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Name: accountName, + Action: "failed", + Message: expiryErr.Error(), + }) + result.Errors = append(result.Errors, CodexSessionImportMessage{ + Index: entry.Index, + Name: accountName, + Message: expiryErr.Error(), + }) + continue + } + item.WarningTexts = append(item.WarningTexts, expiryWarnings...) + if credentialExpiresAt != nil { + item.Credentials["expires_at"] = credentialExpiresAt.Format(time.RFC3339) + } + credentials := mergeCodexImportMap(item.Credentials, credentialExtras) + extra := mergeCodexImportMap(req.Extra, item.Extra) + for _, warning := range item.WarningTexts { + result.Warnings = append(result.Warnings, CodexSessionImportMessage{ + Index: entry.Index, + Name: accountName, + Message: warning, + }) + } + + if duplicateIndex, ok := firstSeenCodexIdentity(seenIdentity, item.IdentityKeys); ok { + message := fmt.Sprintf("与第 %d 条导入项重复,已跳过", duplicateIndex) + result.Skipped++ + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Name: accountName, + Action: "skipped", + Message: message, + }) + result.Warnings = append(result.Warnings, CodexSessionImportMessage{ + Index: entry.Index, + Name: accountName, + Message: message, + }) + continue + } + markCodexIdentitySeen(seenIdentity, item.IdentityKeys, entry.Index) + + if existing := index.Find(item.IdentityKeys); existing != nil && updateExisting { + mergedCredentials := mergeCodexImportCredentials(existing.Credentials, credentials, item) + mergedExtra := mergeCodexImportMap(existing.Extra, extra) + updateInput := &service.UpdateAccountInput{ + Credentials: mergedCredentials, + Extra: mergedExtra, + Concurrency: req.Concurrency, + Priority: req.Priority, + RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, + ExpiresAt: effectiveExpiresAt, + AutoPauseOnExpired: autoPauseOnExpired, + } + if req.ProxyID != nil { + updateInput.ProxyID = req.ProxyID + } + if len(req.GroupIDs) > 0 { + groupIDs := append([]int64(nil), req.GroupIDs...) + updateInput.GroupIDs = &groupIDs + updateInput.SkipMixedChannelCheck = skipMixedChannelCheck + } + updated, updateErr := h.adminService.UpdateAccount(ctx, existing.ID, updateInput) + if updateErr != nil { + result.Failed++ + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Name: accountName, + Action: "failed", + Message: updateErr.Error(), + }) + result.Errors = append(result.Errors, CodexSessionImportMessage{ + Index: entry.Index, + Name: accountName, + Message: updateErr.Error(), + }) + continue + } + if h.tokenCacheInvalidator != nil && updated != nil { + _ = h.tokenCacheInvalidator.InvalidateToken(ctx, updated) + } + result.Updated++ + accountID := existing.ID + if updated != nil { + accountID = updated.ID + index.Add(*updated) + } + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Name: accountName, + Action: "updated", + AccountID: accountID, + }) + continue + } + + account, createErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: accountName, + Notes: req.Notes, + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Credentials: credentials, + Extra: extra, + ProxyID: req.ProxyID, + Concurrency: concurrency, + Priority: priority, + RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, + GroupIDs: req.GroupIDs, + ExpiresAt: effectiveExpiresAt, + AutoPauseOnExpired: autoPauseOnExpired, + SkipDefaultGroupBind: skipDefaultGroupBind, + SkipMixedChannelCheck: skipMixedChannelCheck, + }) + if createErr != nil { + result.Failed++ + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Name: accountName, + Action: "failed", + Message: createErr.Error(), + }) + result.Errors = append(result.Errors, CodexSessionImportMessage{ + Index: entry.Index, + Name: accountName, + Message: createErr.Error(), + }) + continue + } + if account != nil { + index.Add(*account) + } + result.Created++ + accountID := int64(0) + if account != nil { + accountID = account.ID + } + result.Items = append(result.Items, CodexSessionImportItem{ + Index: entry.Index, + Name: accountName, + Action: "created", + AccountID: accountID, + }) + } + + return result, nil +} + +func parseCodexSessionImportEntries(req CodexSessionImportRequest) ([]codexImportEntry, error) { + contents := make([]string, 0, 1+len(req.Contents)) + if strings.TrimSpace(req.Content) != "" { + contents = append(contents, req.Content) + } + for _, content := range req.Contents { + if strings.TrimSpace(content) != "" { + contents = append(contents, content) + } + } + + var entries []codexImportEntry + for _, content := range contents { + values, err := parseCodexSessionImportContent(content) + if err != nil { + return nil, err + } + for _, value := range values { + entries = append(entries, codexImportEntry{ + Index: len(entries) + 1, + Value: value, + }) + } + } + return entries, nil +} + +func parseCodexSessionImportContent(content string) ([]any, error) { + trimmed := strings.TrimSpace(content) + if trimmed == "" { + return nil, nil + } + + if looksLikeJSON(trimmed) { + values, err := decodeCodexJSONStream(trimmed) + if err != nil { + if strings.Contains(trimmed, "\n") { + if lineValues, lineErr := parseCodexSessionImportLines(trimmed); lineErr == nil { + return lineValues, nil + } + } + return nil, fmt.Errorf("JSON 解析失败: %w", err) + } + return flattenCodexImportValues(values), nil + } + + return parseCodexSessionImportLines(trimmed) +} + +func parseCodexSessionImportLines(content string) ([]any, error) { + values := make([]any, 0) + for _, line := range strings.Split(content, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if looksLikeJSON(line) { + lineValues, err := decodeCodexJSONStream(line) + if err != nil { + return nil, fmt.Errorf("第 %d 行 JSON 解析失败: %w", len(values)+1, err) + } + values = append(values, flattenCodexImportValues(lineValues)...) + continue + } + values = append(values, line) + } + return values, nil +} + +func decodeCodexJSONStream(content string) ([]any, error) { + decoder := json.NewDecoder(strings.NewReader(content)) + decoder.UseNumber() + values := make([]any, 0, 1) + for { + var value any + err := decoder.Decode(&value) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + values = append(values, value) + } + if len(values) == 0 { + return nil, errors.New("空 JSON 内容") + } + return values, nil +} + +func flattenCodexImportValues(values []any) []any { + out := make([]any, 0, len(values)) + var appendValue func(any) + appendValue = func(value any) { + if arr, ok := value.([]any); ok { + for _, item := range arr { + appendValue(item) + } + return + } + out = append(out, value) + } + for _, value := range values { + appendValue(value) + } + return out +} + +func normalizeCodexImportEntry(entry codexImportEntry) (*codexImportAccount, error) { + now := time.Now().UTC() + item := &codexImportAccount{ + Credentials: map[string]any{}, + Extra: map[string]any{ + "import_source": "codex_session", + "imported_at": now.Format(time.RFC3339), + }, + } + + switch raw := entry.Value.(type) { + case string: + item.AccessToken = strings.TrimSpace(raw) + case map[string]any: + item.AccessToken = firstCodexString(raw, + []string{"tokens", "access_token"}, + []string{"tokens", "accessToken"}, + []string{"access_token"}, + []string{"accessToken"}, + []string{"token"}, + ) + item.RefreshToken = firstCodexString(raw, + []string{"tokens", "refresh_token"}, + []string{"tokens", "refreshToken"}, + []string{"refresh_token"}, + []string{"refreshToken"}, + ) + item.IDToken = firstCodexString(raw, + []string{"tokens", "id_token"}, + []string{"tokens", "idToken"}, + []string{"id_token"}, + []string{"idToken"}, + ) + item.Email = firstCodexString(raw, []string{"email"}, []string{"user", "email"}) + item.AccountID = firstCodexString(raw, + []string{"chatgpt_account_id"}, + []string{"chatgptAccountId"}, + []string{"account_id"}, + []string{"accountId"}, + []string{"account", "id"}, + []string{"account", "account_id"}, + []string{"account", "chatgpt_account_id"}, + ) + item.UserID = firstCodexString(raw, + []string{"chatgpt_user_id"}, + []string{"chatgptUserId"}, + []string{"user_id"}, + []string{"userId"}, + []string{"user", "id"}, + ) + item.PlanType = firstCodexString(raw, + []string{"plan_type"}, + []string{"planType"}, + []string{"account", "plan_type"}, + []string{"account", "planType"}, + ) + item.Organization = firstCodexString(raw, + []string{"organization_id"}, + []string{"organizationId"}, + []string{"org_id"}, + []string{"orgId"}, + ) + item.Name = firstCodexString(raw, []string{"name"}, []string{"user", "name"}) + authProvider := firstCodexString(raw, []string{"auth_provider"}, []string{"authProvider"}) + if authProvider != "" { + item.Extra["auth_provider"] = authProvider + } + if sessionToken := firstCodexString(raw, []string{"session_token"}, []string{"sessionToken"}); sessionToken != "" { + item.Extra["session_token_present"] = true + item.WarningTexts = append(item.WarningTexts, "sessionToken 已忽略,不会作为 OAuth refresh_token 存储") + } + if sessionExpiresAt, ok := codexTimeAt(raw, []string{"expires"}); ok { + item.Extra["session_expires_at"] = sessionExpiresAt.Format(time.RFC3339) + } + if tokenExpiresAt, ok := firstCodexTime(raw, + []string{"tokens", "expires_at"}, + []string{"tokens", "expiresAt"}, + []string{"expires_at"}, + []string{"expiresAt"}, + ); ok { + if tokenExpiresAt.Unix() <= now.Unix()-codexImportClockSkewSeconds { + return nil, fmt.Errorf("access_token 已过期: %s", tokenExpiresAt.Format(time.RFC3339)) + } + item.TokenExpiresAt = &tokenExpiresAt + item.Credentials["expires_at"] = tokenExpiresAt.Format(time.RFC3339) + } + copyCodexExtraString(raw, item.Extra, "user_image", []string{"user", "image"}) + copyCodexExtraString(raw, item.Extra, "user_picture", []string{"user", "picture"}) + copyCodexExtraString(raw, item.Extra, "account_structure", []string{"account", "structure"}) + copyCodexExtraString(raw, item.Extra, "account_residency_region", []string{"account", "residencyRegion"}) + copyCodexExtraString(raw, item.Extra, "compute_residency", []string{"account", "computeResidency"}) + default: + return nil, fmt.Errorf("第 %d 条格式不支持", entry.Index) + } + + if item.AccessToken == "" { + return nil, errors.New("缺少 accessToken/access_token") + } + item.Credentials["access_token"] = item.AccessToken + if item.RefreshToken != "" { + item.Credentials["refresh_token"] = item.RefreshToken + item.Credentials["client_id"] = openai.ClientID + } + if item.IDToken != "" { + item.Credentials["id_token"] = item.IDToken + if err := enrichCodexImportAccountFromJWT(item, item.IDToken, false, now); err != nil { + return nil, err + } + } + if err := enrichCodexImportAccountFromJWT(item, item.AccessToken, true, now); err != nil { + return nil, err + } + if _, ok := item.Credentials["expires_at"]; !ok { + item.WarningTexts = append(item.WarningTexts, "无法从 accessToken 解析过期时间,导入后需自行确认令牌有效性") + } + if item.RefreshToken == "" { + item.WarningTexts = append(item.WarningTexts, "未包含 refresh_token,accessToken 过期后无法自动续期") + } + + setCodexCredentialIfNotEmpty(item.Credentials, "email", item.Email) + setCodexCredentialIfNotEmpty(item.Credentials, "chatgpt_account_id", item.AccountID) + setCodexCredentialIfNotEmpty(item.Credentials, "chatgpt_user_id", item.UserID) + setCodexCredentialIfNotEmpty(item.Credentials, "organization_id", item.Organization) + setCodexCredentialIfNotEmpty(item.Credentials, "plan_type", item.PlanType) + + fingerprint := codexTokenFingerprint(item.AccessToken) + item.Extra["access_token_sha256"] = fingerprint + item.IdentityKeys = buildCodexIdentityKeys(item.AccountID, item.UserID, item.Email, item.AccessToken) + item.Name = buildCodexImportAccountName(item, entry.Index) + + return item, nil +} + +func enrichCodexImportAccountFromJWT(item *codexImportAccount, token string, validateExpiry bool, now time.Time) error { + claims, err := decodeCodexJWTClaims(token) + if err != nil { + if validateExpiry { + item.WarningTexts = append(item.WarningTexts, "accessToken 不是可解析 JWT,无法校验过期时间和账号身份") + } + return nil + } + if validateExpiry && claims.Exp > 0 { + if now.Unix() > claims.Exp+codexImportClockSkewSeconds { + return fmt.Errorf("access_token 已过期: %s", time.Unix(claims.Exp, 0).UTC().Format(time.RFC3339)) + } + expiresAt := time.Unix(claims.Exp, 0).UTC() + item.TokenExpiresAt = &expiresAt + item.Credentials["expires_at"] = expiresAt.Format(time.RFC3339) + } + if item.Email == "" { + item.Email = strings.TrimSpace(claims.Email) + } + if claims.OpenAIAuth == nil { + if item.UserID == "" { + item.UserID = strings.TrimSpace(claims.Sub) + } + return nil + } + if item.AccountID == "" { + item.AccountID = strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID) + } + if item.UserID == "" { + item.UserID = strings.TrimSpace(claims.OpenAIAuth.ChatGPTUserID) + } + if item.UserID == "" { + item.UserID = strings.TrimSpace(claims.OpenAIAuth.UserID) + } + if item.PlanType == "" { + item.PlanType = strings.TrimSpace(claims.OpenAIAuth.ChatGPTPlanType) + } + if item.Organization == "" { + item.Organization = strings.TrimSpace(claims.OpenAIAuth.POID) + } + if item.Organization == "" { + for _, org := range claims.OpenAIAuth.Organizations { + if org.IsDefault { + item.Organization = org.ID + break + } + } + } + if item.Organization == "" && len(claims.OpenAIAuth.Organizations) > 0 { + item.Organization = claims.OpenAIAuth.Organizations[0].ID + } + if item.UserID == "" { + item.UserID = strings.TrimSpace(claims.Sub) + } + return nil +} + +func decodeCodexJWTClaims(token string) (*codexJWTClaims, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT format") + } + payload, err := decodeCodexJWTSegment(parts[1]) + if err != nil { + return nil, err + } + var claims codexJWTClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, err + } + return &claims, nil +} + +func decodeCodexJWTSegment(segment string) ([]byte, error) { + if decoded, err := base64.RawURLEncoding.DecodeString(segment); err == nil { + return decoded, nil + } + if decoded, err := base64.RawStdEncoding.DecodeString(segment); err == nil { + return decoded, nil + } + padded := segment + if rem := len(padded) % 4; rem > 0 { + padded += strings.Repeat("=", 4-rem) + } + if decoded, err := base64.URLEncoding.DecodeString(padded); err == nil { + return decoded, nil + } + return base64.StdEncoding.DecodeString(padded) +} + +func buildCodexImportAccountName(item *codexImportAccount, index int) string { + for _, candidate := range []string{item.Name, item.Email, item.AccountID, item.UserID} { + candidate = strings.TrimSpace(candidate) + if candidate != "" { + return candidate + } + } + return fmt.Sprintf("Codex 导入账号 %d", index) +} + +func buildCodexCreateAccountName(base string, item *codexImportAccount, index, total int) string { + base = strings.TrimSpace(base) + if base == "" { + if item == nil { + return fmt.Sprintf("Codex 导入账号 %d", index) + } + return item.Name + } + if total > 1 { + return fmt.Sprintf("%s #%d", base, index) + } + return base +} + +func resolveCodexImportExpiry(req CodexSessionImportRequest, item *codexImportAccount) (*int64, *time.Time, *bool, []string, error) { + if item == nil { + return nil, nil, nil, nil, errors.New("导入项为空") + } + + var requestExpiresAt *time.Time + if req.ExpiresAt != nil && *req.ExpiresAt > 0 { + t := time.Unix(*req.ExpiresAt, 0).UTC() + requestExpiresAt = &t + } + + var accountExpiresAt *time.Time + var credentialExpiresAt *time.Time + warnings := make([]string, 0, 2) + if item.RefreshToken == "" { + if item.TokenExpiresAt != nil { + tokenExpiresAt := item.TokenExpiresAt.UTC() + accountExpiresAt = &tokenExpiresAt + credentialExpiresAt = &tokenExpiresAt + } + if requestExpiresAt != nil { + accountExpiresAt = earlierCodexTime(accountExpiresAt, requestExpiresAt) + credentialExpiresAt = earlierCodexTime(credentialExpiresAt, requestExpiresAt) + } + if accountExpiresAt == nil { + return nil, nil, nil, nil, errors.New("未包含 refresh_token,且无法解析 accessToken 过期时间;请在第一步设置过期时间后再导入") + } + if accountExpiresAt.Unix() <= time.Now().UTC().Unix()-codexImportClockSkewSeconds { + return nil, nil, nil, nil, fmt.Errorf("过期时间已过期: %s", accountExpiresAt.Format(time.RFC3339)) + } + warnings = append(warnings, "未包含 refresh_token,已按 accessToken/账号过期时间设置自动停止调度") + if req.AutoPauseOnExpired != nil && !*req.AutoPauseOnExpired { + warnings = append(warnings, "未包含 refresh_token,已强制开启过期自动暂停") + } + autoPause := true + expiresAtUnix := accountExpiresAt.Unix() + return &expiresAtUnix, credentialExpiresAt, &autoPause, warnings, nil + } + + if requestExpiresAt != nil { + accountExpiresAt = requestExpiresAt + } + if item.TokenExpiresAt != nil { + tokenExpiresAt := item.TokenExpiresAt.UTC() + credentialExpiresAt = &tokenExpiresAt + } + var expiresAtUnix *int64 + if accountExpiresAt != nil { + v := accountExpiresAt.Unix() + expiresAtUnix = &v + } + return expiresAtUnix, credentialExpiresAt, req.AutoPauseOnExpired, warnings, nil +} + +func earlierCodexTime(current, candidate *time.Time) *time.Time { + if candidate == nil { + return current + } + if current == nil || candidate.Before(*current) { + t := candidate.UTC() + return &t + } + t := current.UTC() + return &t +} + +func sanitizeCodexImportCredentialExtras(input map[string]any) map[string]any { + if len(input) == 0 { + return nil + } + protected := map[string]struct{}{ + "access_token": {}, + "refresh_token": {}, + "id_token": {}, + "expires_at": {}, + "email": {}, + "chatgpt_account_id": {}, + "chatgpt_user_id": {}, + "organization_id": {}, + "plan_type": {}, + "client_id": {}, + } + out := make(map[string]any, len(input)) + for key, value := range input { + normalizedKey := strings.TrimSpace(key) + if normalizedKey == "" { + continue + } + if _, ok := protected[strings.ToLower(normalizedKey)]; ok { + continue + } + out[normalizedKey] = value + } + if len(out) == 0 { + return nil + } + return out +} + +func buildCodexIdentityKeys(accountID, userID, email, accessToken string) []string { + keys := make([]string, 0, 4) + accountID = strings.TrimSpace(accountID) + userID = strings.TrimSpace(userID) + if accountID != "" { + keys = append(keys, "account:"+accountID) + } + if userID != "" { + keys = append(keys, "user:"+userID) + } + if accountID == "" && userID == "" { + if email = strings.ToLower(strings.TrimSpace(email)); email != "" { + keys = append(keys, "email:"+email) + } + } + if accessToken = strings.TrimSpace(accessToken); accessToken != "" { + keys = append(keys, "access:"+codexTokenFingerprint(accessToken)) + } + return keys +} + +func buildCodexAccountIndex(accounts []service.Account) *codexAccountIndex { + index := &codexAccountIndex{accountsByKey: map[string]service.Account{}} + for _, account := range accounts { + index.Add(account) + } + return index +} + +func (i *codexAccountIndex) Add(account service.Account) { + if i == nil { + return + } + if i.accountsByKey == nil { + i.accountsByKey = map[string]service.Account{} + } + keys := buildCodexIdentityKeys( + codexCredentialString(account.Credentials, "chatgpt_account_id"), + codexCredentialString(account.Credentials, "chatgpt_user_id"), + codexCredentialString(account.Credentials, "email"), + codexCredentialString(account.Credentials, "access_token"), + ) + for _, key := range keys { + i.accountsByKey[key] = account + } +} + +func (i *codexAccountIndex) Find(keys []string) *service.Account { + if i == nil { + return nil + } + for _, key := range keys { + if account, ok := i.accountsByKey[key]; ok { + return &account + } + } + return nil +} + +func firstSeenCodexIdentity(seen map[string]int, keys []string) (int, bool) { + for _, key := range keys { + if index, ok := seen[key]; ok { + return index, true + } + } + return 0, false +} + +func markCodexIdentitySeen(seen map[string]int, keys []string, index int) { + for _, key := range keys { + seen[key] = index + } +} + +func mergeCodexImportMap(existing, incoming map[string]any) map[string]any { + out := make(map[string]any, len(existing)+len(incoming)) + for k, v := range existing { + out[k] = v + } + for k, v := range incoming { + out[k] = v + } + return out +} + +func mergeCodexImportCredentials(existing, incoming map[string]any, item *codexImportAccount) map[string]any { + out := mergeCodexImportMap(existing, incoming) + if item == nil { + return out + } + if strings.TrimSpace(item.RefreshToken) == "" { + delete(out, "refresh_token") + delete(out, "client_id") + } + if strings.TrimSpace(item.IDToken) == "" { + delete(out, "id_token") + } + return out +} + +func codexCredentialString(credentials map[string]any, key string) string { + if credentials == nil { + return "" + } + return codexStringValue(credentials[key]) +} + +func codexTokenFingerprint(token string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(token))) + return hex.EncodeToString(sum[:]) +} + +func looksLikeJSON(content string) bool { + if content == "" { + return false + } + switch content[0] { + case '{', '[': + return true + default: + return false + } +} + +func firstCodexString(obj map[string]any, paths ...[]string) string { + for _, path := range paths { + if value, ok := codexPathValue(obj, path); ok { + if str := codexStringValue(value); str != "" { + return str + } + } + } + return "" +} + +func copyCodexExtraString(obj map[string]any, extra map[string]any, key string, path []string) { + value := firstCodexString(obj, path) + if value != "" { + extra[key] = value + } +} + +func firstCodexTime(obj map[string]any, paths ...[]string) (time.Time, bool) { + for _, path := range paths { + if value, ok := codexTimeAt(obj, path); ok { + return value, true + } + } + return time.Time{}, false +} + +func codexTimeAt(obj map[string]any, path []string) (time.Time, bool) { + value, ok := codexPathValue(obj, path) + if !ok { + return time.Time{}, false + } + return parseCodexTimeValue(value) +} + +func codexPathValue(obj map[string]any, path []string) (any, bool) { + var current any = obj + for _, key := range path { + currentObj, ok := current.(map[string]any) + if !ok { + return nil, false + } + value, ok := currentObj[key] + if !ok { + return nil, false + } + current = value + } + return current, true +} + +func codexStringValue(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case json.Number: + return strings.TrimSpace(v.String()) + case float64: + return strings.TrimSpace(strconv.FormatFloat(v, 'f', -1, 64)) + case float32: + return strings.TrimSpace(strconv.FormatFloat(float64(v), 'f', -1, 32)) + case int: + return strconv.Itoa(v) + case int64: + return strconv.FormatInt(v, 10) + case int32: + return strconv.FormatInt(int64(v), 10) + default: + return "" + } +} + +func setCodexCredentialIfNotEmpty(credentials map[string]any, key, value string) { + value = strings.TrimSpace(value) + if value != "" { + credentials[key] = value + } +} + +func parseCodexTimeValue(value any) (time.Time, bool) { + switch v := value.(type) { + case string: + v = strings.TrimSpace(v) + if v == "" { + return time.Time{}, false + } + if parsed, err := time.Parse(time.RFC3339Nano, v); err == nil { + return parsed.UTC(), true + } + if n, err := strconv.ParseInt(v, 10, 64); err == nil { + return codexUnixTime(n), true + } + case json.Number: + if n, err := v.Int64(); err == nil { + return codexUnixTime(n), true + } + if f, err := v.Float64(); err == nil { + return codexUnixTime(int64(f)), true + } + case float64: + return codexUnixTime(int64(v)), true + case int: + return codexUnixTime(int64(v)), true + case int64: + return codexUnixTime(v), true + } + return time.Time{}, false +} + +func codexUnixTime(value int64) time.Time { + if value > 1_000_000_000_000 { + return time.UnixMilli(value).UTC() + } + return time.Unix(value, 0).UTC() +} diff --git a/backend/internal/handler/admin/account_codex_import_test.go b/backend/internal/handler/admin/account_codex_import_test.go new file mode 100644 index 00000000..3cf0d2bb --- /dev/null +++ b/backend/internal/handler/admin/account_codex_import_test.go @@ -0,0 +1,344 @@ +package admin + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "testing" + "time" +) + +func TestParseCodexSessionImportEntriesSupportsRawTokenJSONAndArray(t *testing.T) { + token1 := "raw-access-token-1" + token2 := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{ + "email": "json@example.com", + }) + token3 := "raw-access-token-3" + + req := CodexSessionImportRequest{ + Content: fmt.Sprintf("%s\n{\"accessToken\":%q}\n[%q]", token1, token2, token3), + } + + entries, err := parseCodexSessionImportEntries(req) + if err != nil { + t.Fatalf("parseCodexSessionImportEntries error = %v", err) + } + if len(entries) != 3 { + t.Fatalf("len(entries) = %d, want 3", len(entries)) + } + + first, err := normalizeCodexImportEntry(entries[0]) + if err != nil { + t.Fatalf("normalize raw token error = %v", err) + } + if first.Credentials["access_token"] != token1 { + t.Fatalf("raw token access_token = %v, want %s", first.Credentials["access_token"], token1) + } + + second, err := normalizeCodexImportEntry(entries[1]) + if err != nil { + t.Fatalf("normalize json token error = %v", err) + } + if second.Email != "json@example.com" { + t.Fatalf("email = %q, want json@example.com", second.Email) + } + + third, err := normalizeCodexImportEntry(entries[2]) + if err != nil { + t.Fatalf("normalize array token error = %v", err) + } + if third.Credentials["access_token"] != token3 { + t.Fatalf("array token access_token = %v, want %s", third.Credentials["access_token"], token3) + } +} + +func TestParseCodexSessionImportEntriesFallsBackToLineModeForMixedJSONAndToken(t *testing.T) { + req := CodexSessionImportRequest{ + Content: "{\"accessToken\":\"json-line-token\"}\nraw-line-token", + } + + entries, err := parseCodexSessionImportEntries(req) + if err != nil { + t.Fatalf("parseCodexSessionImportEntries error = %v", err) + } + if len(entries) != 2 { + t.Fatalf("len(entries) = %d, want 2", len(entries)) + } + + first, err := normalizeCodexImportEntry(entries[0]) + if err != nil { + t.Fatalf("normalize json line error = %v", err) + } + if first.Credentials["access_token"] != "json-line-token" { + t.Fatalf("json line access_token = %v, want json-line-token", first.Credentials["access_token"]) + } + + second, err := normalizeCodexImportEntry(entries[1]) + if err != nil { + t.Fatalf("normalize raw line error = %v", err) + } + if second.Credentials["access_token"] != "raw-line-token" { + t.Fatalf("raw line access_token = %v, want raw-line-token", second.Credentials["access_token"]) + } +} + +func TestNormalizeCodexSessionJSONExtractsCredentialsAndIgnoresSessionToken(t *testing.T) { + accessToken := buildCodexImportTestJWT(t, time.Now().Add(time.Hour), map[string]any{ + "email": "claim@example.com", + "https://api.openai.com/auth": map[string]any{ + "chatgpt_account_id": "acct-from-claim", + "chatgpt_user_id": "user-from-claim", + "chatgpt_plan_type": "plus", + "poid": "org-from-claim", + }, + }) + raw := map[string]any{ + "user": map[string]any{ + "id": "user-from-json", + "name": "Sup OO", + "email": "json@example.com", + "image": "https://example.com/avatar.png", + }, + "account": map[string]any{ + "id": "acct-from-json", + "planType": "free", + }, + "accessToken": accessToken, + "sessionToken": "secret-session-token", + "expires": "2026-08-05T13:40:42.836Z", + } + + item, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: raw}) + if err != nil { + t.Fatalf("normalizeCodexImportEntry error = %v", err) + } + if item.Credentials["access_token"] != accessToken { + t.Fatalf("access_token not stored") + } + if item.Credentials["email"] != "json@example.com" { + t.Fatalf("email = %v, want json@example.com", item.Credentials["email"]) + } + if item.Credentials["chatgpt_account_id"] != "acct-from-json" { + t.Fatalf("chatgpt_account_id = %v, want acct-from-json", item.Credentials["chatgpt_account_id"]) + } + if item.Credentials["chatgpt_user_id"] != "user-from-json" { + t.Fatalf("chatgpt_user_id = %v, want user-from-json", item.Credentials["chatgpt_user_id"]) + } + if item.Credentials["plan_type"] != "free" { + t.Fatalf("plan_type = %v, want free", item.Credentials["plan_type"]) + } + if _, ok := item.Credentials["session_token"]; ok { + t.Fatalf("session_token should not be written to credentials") + } + if item.Extra["session_token_present"] != true { + t.Fatalf("session_token_present = %v, want true", item.Extra["session_token_present"]) + } + if item.Extra["session_expires_at"] != "2026-08-05T13:40:42Z" { + t.Fatalf("session_expires_at = %v", item.Extra["session_expires_at"]) + } + if item.TokenExpiresAt == nil { + t.Fatalf("TokenExpiresAt should be parsed from accessToken") + } +} + +func TestMergeCodexImportCredentialsClearsStaleRefreshFieldsWhenIncomingHasNoRefreshToken(t *testing.T) { + existing := map[string]any{ + "access_token": "old-access-token", + "refresh_token": "old-refresh-token", + "client_id": "old-client-id", + "id_token": "old-id-token", + "model_mapping": map[string]any{"from": "existing"}, + "chatgpt_account_id": "acct-old", + "unrelated_existing": "keep", + } + incoming := map[string]any{ + "access_token": "new-access-token", + "expires_at": "2026-08-05T13:40:42Z", + "chatgpt_account_id": "acct-new", + } + item := &codexImportAccount{ + AccessToken: "new-access-token", + } + + merged := mergeCodexImportCredentials(existing, incoming, item) + + if merged["access_token"] != "new-access-token" { + t.Fatalf("access_token = %v, want new-access-token", merged["access_token"]) + } + if merged["chatgpt_account_id"] != "acct-new" { + t.Fatalf("chatgpt_account_id = %v, want acct-new", merged["chatgpt_account_id"]) + } + if _, ok := merged["refresh_token"]; ok { + t.Fatalf("refresh_token should be cleared") + } + if _, ok := merged["client_id"]; ok { + t.Fatalf("client_id should be cleared") + } + if _, ok := merged["id_token"]; ok { + t.Fatalf("id_token should be cleared") + } + if merged["unrelated_existing"] != "keep" { + t.Fatalf("unrelated_existing = %v, want keep", merged["unrelated_existing"]) + } + if _, ok := merged["model_mapping"]; !ok { + t.Fatalf("model_mapping should be preserved") + } +} + +func TestMergeCodexImportCredentialsKeepsRefreshFieldsWhenIncomingHasRefreshToken(t *testing.T) { + existing := map[string]any{ + "refresh_token": "old-refresh-token", + "client_id": "old-client-id", + "id_token": "old-id-token", + } + incoming := map[string]any{ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "client_id": "new-client-id", + "id_token": "new-id-token", + } + item := &codexImportAccount{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + IDToken: "new-id-token", + } + + merged := mergeCodexImportCredentials(existing, incoming, item) + + if merged["refresh_token"] != "new-refresh-token" { + t.Fatalf("refresh_token = %v, want new-refresh-token", merged["refresh_token"]) + } + if merged["client_id"] != "new-client-id" { + t.Fatalf("client_id = %v, want new-client-id", merged["client_id"]) + } + if merged["id_token"] != "new-id-token" { + t.Fatalf("id_token = %v, want new-id-token", merged["id_token"]) + } +} + +func TestNormalizeCodexImportRejectsExpiredAccessToken(t *testing.T) { + expiredToken := buildCodexImportTestJWT(t, time.Now().Add(-time.Hour), map[string]any{}) + + _, err := normalizeCodexImportEntry(codexImportEntry{Index: 1, Value: expiredToken}) + if err == nil { + t.Fatal("normalizeCodexImportEntry error = nil, want expired token error") + } + if !strings.Contains(err.Error(), "已过期") { + t.Fatalf("error = %v, want expired token message", err) + } +} + +func TestResolveCodexImportExpiryForNoRefreshTokenUsesTokenExpiry(t *testing.T) { + tokenExpiresAt := time.Now().Add(time.Hour).UTC() + item := &codexImportAccount{ + AccessToken: "access-token", + Credentials: map[string]any{"access_token": "access-token"}, + TokenExpiresAt: &tokenExpiresAt, + WarningTexts: []string{}, + } + disabled := false + req := CodexSessionImportRequest{AutoPauseOnExpired: &disabled} + + accountExpiresAt, credentialExpiresAt, autoPause, warnings, err := resolveCodexImportExpiry(req, item) + if err != nil { + t.Fatalf("resolveCodexImportExpiry error = %v", err) + } + if accountExpiresAt == nil || *accountExpiresAt != tokenExpiresAt.Unix() { + t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, tokenExpiresAt.Unix()) + } + if credentialExpiresAt == nil || credentialExpiresAt.Unix() != tokenExpiresAt.Unix() { + t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, tokenExpiresAt) + } + if autoPause == nil || !*autoPause { + t.Fatalf("autoPause = %v, want true", autoPause) + } + if len(warnings) == 0 { + t.Fatalf("warnings should not be empty") + } +} + +func TestResolveCodexImportExpiryForNoRefreshTokenRequiresExpiry(t *testing.T) { + item := &codexImportAccount{ + AccessToken: "opaque-access-token", + Credentials: map[string]any{"access_token": "opaque-access-token"}, + WarningTexts: []string{}, + } + + _, _, _, _, err := resolveCodexImportExpiry(CodexSessionImportRequest{}, item) + if err == nil { + t.Fatal("resolveCodexImportExpiry error = nil, want missing expiry error") + } + if !strings.Contains(err.Error(), "无法解析 accessToken 过期时间") { + t.Fatalf("error = %v, want missing expiry message", err) + } +} + +func TestResolveCodexImportExpiryForNoRefreshTokenUsesEarlierRequestExpiry(t *testing.T) { + tokenExpiresAt := time.Now().Add(2 * time.Hour).UTC() + requestExpiresAt := time.Now().Add(time.Hour).UTC() + item := &codexImportAccount{ + AccessToken: "access-token", + Credentials: map[string]any{"access_token": "access-token"}, + TokenExpiresAt: &tokenExpiresAt, + WarningTexts: []string{}, + } + reqUnix := requestExpiresAt.Unix() + req := CodexSessionImportRequest{ExpiresAt: &reqUnix} + + accountExpiresAt, credentialExpiresAt, _, _, err := resolveCodexImportExpiry(req, item) + if err != nil { + t.Fatalf("resolveCodexImportExpiry error = %v", err) + } + if accountExpiresAt == nil || *accountExpiresAt != requestExpiresAt.Unix() { + t.Fatalf("account expires_at = %v, want %d", accountExpiresAt, requestExpiresAt.Unix()) + } + if credentialExpiresAt == nil || credentialExpiresAt.Unix() != requestExpiresAt.Unix() { + t.Fatalf("credential expires_at = %v, want %s", credentialExpiresAt, requestExpiresAt) + } +} + +func TestCodexIdentityKeysPreferStrongIdentifiers(t *testing.T) { + keys := buildCodexIdentityKeys("acct-1", "user-1", "same@example.com", "token") + for _, key := range keys { + if strings.HasPrefix(key, "email:") { + t.Fatalf("strong identity should not include email fallback: %v", keys) + } + } + + keys = buildCodexIdentityKeys("", "", "same@example.com", "token") + hasEmail := false + for _, key := range keys { + if key == "email:same@example.com" { + hasEmail = true + } + } + if !hasEmail { + t.Fatalf("weak identity should include email fallback: %v", keys) + } +} + +func buildCodexImportTestJWT(t *testing.T, exp time.Time, extraClaims map[string]any) string { + t.Helper() + header := map[string]any{ + "alg": "none", + "typ": "JWT", + } + claims := map[string]any{ + "sub": "user-from-sub", + "exp": exp.Unix(), + "iat": time.Now().Unix(), + } + for k, v := range extraClaims { + claims[k] = v + } + headerBytes, err := json.Marshal(header) + if err != nil { + t.Fatalf("marshal header: %v", err) + } + claimBytes, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + return base64.RawURLEncoding.EncodeToString(headerBytes) + "." + base64.RawURLEncoding.EncodeToString(claimBytes) + "." +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index b187b47f..2fef94f1 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -175,6 +175,10 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, return &user, nil } +func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) { + return len(userIDs), nil +} + func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) { return s.apiKeys, int64(len(s.apiKeys)), nil } diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go new file mode 100644 index 00000000..4266f5d8 --- /dev/null +++ b/backend/internal/handler/admin/content_moderation_handler.go @@ -0,0 +1,238 @@ +package admin + +import ( + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type ContentModerationHandler struct { + service *service.ContentModerationService +} + +func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler { + return &ContentModerationHandler{service: svc} +} + +type contentModerationConfigRequest struct { + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + APIKeysMode string `json:"api_keys_mode"` + DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` +} + +type contentModerationAPIKeyTestRequest struct { + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + TimeoutMS int `json:"timeout_ms"` + Prompt string `json:"prompt"` + Images []string `json:"images"` +} + +type contentModerationHashRequest struct { + InputHash string `json:"input_hash"` +} + +func (h *ContentModerationHandler) GetConfig(c *gin.Context) { + cfg, err := h.service.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) { + var req contentModerationConfigRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{ + Enabled: req.Enabled, + Mode: req.Mode, + BaseURL: req.BaseURL, + Model: req.Model, + APIKey: req.APIKey, + APIKeys: req.APIKeys, + APIKeysMode: req.APIKeysMode, + DeleteAPIKeyHashes: req.DeleteAPIKeyHashes, + ClearAPIKey: req.ClearAPIKey, + TimeoutMS: req.TimeoutMS, + SampleRate: req.SampleRate, + AllGroups: req.AllGroups, + GroupIDs: req.GroupIDs, + RecordNonHits: req.RecordNonHits, + WorkerCount: req.WorkerCount, + QueueSize: req.QueueSize, + BlockStatus: req.BlockStatus, + BlockMessage: req.BlockMessage, + EmailOnHit: req.EmailOnHit, + AutoBanEnabled: req.AutoBanEnabled, + BanThreshold: req.BanThreshold, + ViolationWindowHours: req.ViolationWindowHours, + RetryCount: req.RetryCount, + HitRetentionDays: req.HitRetentionDays, + NonHitRetentionDays: req.NonHitRetentionDays, + PreHashCheckEnabled: req.PreHashCheckEnabled, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) { + var req contentModerationAPIKeyTestRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{ + APIKeys: req.APIKeys, + BaseURL: req.BaseURL, + Model: req.Model, + TimeoutMS: req.TimeoutMS, + Prompt: req.Prompt, + Images: req.Images, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) GetStatus(c *gin.Context) { + status, err := h.service.GetStatus(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, status) +} + +func (h *ContentModerationHandler) ListLogs(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := service.ContentModerationLogFilter{ + Pagination: pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortOrder: pagination.SortOrderDesc, + }, + Result: c.Query("result"), + Endpoint: c.Query("endpoint"), + Search: c.Query("search"), + } + if raw := strings.TrimSpace(c.Query("group_id")); raw != "" { + groupID, err := strconv.ParseInt(raw, 10, 64) + if err != nil || groupID <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &groupID + } + if raw := strings.TrimSpace(c.Query("from")); raw != "" { + t, _, err := parseContentModerationDate(raw) + if err != nil { + response.BadRequest(c, "Invalid from") + return + } + filter.From = &t + } + if raw := strings.TrimSpace(c.Query("to")); raw != "" { + t, dateOnly, err := parseContentModerationDate(raw) + if err != nil { + response.BadRequest(c, "Invalid to") + return + } + if dateOnly { + t = t.Add(24*time.Hour - time.Nanosecond) + } + filter.To = &t + } + items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize) +} + +func (h *ContentModerationHandler) UnbanUser(c *gin.Context) { + userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + result, err := h.service.UnbanUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) { + var req contentModerationHashRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) { + result, err := h.service.ClearFlaggedInputHashes(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func parseContentModerationDate(raw string) (time.Time, bool, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return time.Time{}, false, nil + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + return t, false, nil + } + t, err := time.Parse("2006-01-02", raw) + return t, err == nil, err +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0cec89aa..7ad51660 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -117,6 +117,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), + LoginAgreementEnabled: settings.LoginAgreementEnabled, + LoginAgreementMode: settings.LoginAgreementMode, + LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt, + LoginAgreementDocuments: loginAgreementDocumentsToDTO(settings.LoginAgreementDocuments), SMTPHost: settings.SMTPHost, SMTPPort: settings.SMTPPort, SMTPUsername: settings.SMTPUsername, @@ -169,6 +173,16 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath, OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath, OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath, + GitHubOAuthEnabled: settings.GitHubOAuthEnabled, + GitHubOAuthClientID: settings.GitHubOAuthClientID, + GitHubOAuthClientSecretConfigured: settings.GitHubOAuthClientSecretConfigured, + GitHubOAuthRedirectURL: settings.GitHubOAuthRedirectURL, + GitHubOAuthFrontendRedirectURL: settings.GitHubOAuthFrontendRedirectURL, + GoogleOAuthEnabled: settings.GoogleOAuthEnabled, + GoogleOAuthClientID: settings.GoogleOAuthClientID, + GoogleOAuthClientSecretConfigured: settings.GoogleOAuthClientSecretConfigured, + GoogleOAuthRedirectURL: settings.GoogleOAuthRedirectURL, + GoogleOAuthFrontendRedirectURL: settings.GoogleOAuthFrontendRedirectURL, SiteName: settings.SiteName, SiteLogo: settings.SiteLogo, SiteSubtitle: settings.SiteSubtitle, @@ -185,6 +199,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + RiskControlEnabled: settings.RiskControlEnabled, AffiliateRebateRate: settings.AffiliateRebateRate, AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours, AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays, @@ -294,17 +309,50 @@ func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.O return &service.OpenAIFastPolicySettings{Rules: rules} } +func loginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument { + result := make([]dto.LoginAgreementDocument, 0, len(items)) + for _, item := range items { + result = append(result, dto.LoginAgreementDocument{ + ID: item.ID, + Title: item.Title, + ContentMD: item.ContentMD, + }) + } + return result +} + +func loginAgreementDocumentsToService(items []dto.LoginAgreementDocument) []service.LoginAgreementDocument { + result := make([]service.LoginAgreementDocument, 0, len(items)) + for _, item := range items { + title := strings.TrimSpace(item.Title) + content := strings.TrimSpace(item.ContentMD) + if title == "" && content == "" { + continue + } + result = append(result, service.LoginAgreementDocument{ + ID: strings.TrimSpace(item.ID), + Title: title, + ContentMD: content, + }) + } + return result +} + // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - FrontendURL string `json:"frontend_url"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + FrontendURL string `json:"frontend_url"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + LoginAgreementEnabled bool `json:"login_agreement_enabled"` + LoginAgreementMode string `json:"login_agreement_mode"` + LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"` + LoginAgreementDocuments []dto.LoginAgreementDocument `json:"login_agreement_documents"` // 邮件服务设置 SMTPHost string `json:"smtp_host"` @@ -368,6 +416,17 @@ type UpdateSettingsRequest struct { OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"` OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"` + GitHubOAuthEnabled bool `json:"github_oauth_enabled"` + GitHubOAuthClientID string `json:"github_oauth_client_id"` + GitHubOAuthClientSecret string `json:"github_oauth_client_secret"` + GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"` + GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"` + GoogleOAuthEnabled bool `json:"google_oauth_enabled"` + GoogleOAuthClientID string `json:"google_oauth_client_id"` + GoogleOAuthClientSecret string `json:"google_oauth_client_secret"` + GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"` + GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"` + // OEM设置 SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` @@ -413,6 +472,16 @@ type UpdateSettingsRequest struct { AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` + AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"` + AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"` + AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"` + AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"` + AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"` + AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"` + AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"` + AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"` + AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"` + AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"` ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration @@ -497,6 +566,9 @@ type UpdateSettingsRequest struct { // Affiliate (邀请返利) feature switch AffiliateEnabled *bool `json:"affiliate_enabled"` + // 风控中心功能开关 + RiskControlEnabled *bool `json:"risk_control_enabled"` + // OpenAI fast/flex policy (optional, only updated when provided) OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } @@ -633,6 +705,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } } + loginAgreementMode := strings.ToLower(strings.TrimSpace(req.LoginAgreementMode)) + if loginAgreementMode == "" { + loginAgreementMode = strings.ToLower(strings.TrimSpace(previousSettings.LoginAgreementMode)) + } + switch loginAgreementMode { + case "", "modal": + loginAgreementMode = "modal" + case "checkbox": + default: + response.BadRequest(c, "Login agreement mode must be modal or checkbox") + return + } + loginAgreementUpdatedAt := strings.TrimSpace(req.LoginAgreementUpdatedAt) + if loginAgreementUpdatedAt == "" { + loginAgreementUpdatedAt = strings.TrimSpace(previousSettings.LoginAgreementUpdatedAt) + } + loginAgreementDocuments := loginAgreementDocumentsToService(req.LoginAgreementDocuments) + if len(loginAgreementDocuments) == 0 { + loginAgreementDocuments = previousSettings.LoginAgreementDocuments + } + for _, doc := range loginAgreementDocuments { + if strings.TrimSpace(doc.Title) == "" { + response.BadRequest(c, "Login agreement document title is required") + return + } + if len(doc.Title) > 80 { + response.BadRequest(c, "Login agreement document title is too long (max 80 characters)") + return + } + if len(doc.ContentMD) > 200*1024 { + response.BadRequest(c, "Login agreement document content is too large (max 200KB)") + return + } + } + if req.LoginAgreementEnabled && len(loginAgreementDocuments) == 0 { + response.BadRequest(c, "Login agreement documents are required when enabled") + return + } // LinuxDo Connect 参数验证 if req.LinuxDoConnectEnabled { @@ -994,17 +1104,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") return } - if strings.TrimSpace(item.URL) == "" { - response.BadRequest(c, "Custom menu item URL is required") - return - } - if len(item.URL) > maxMenuItemURLLen { - response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") - return - } - if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { - response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") - return + urlTrimmed := strings.TrimSpace(item.URL) + if strings.HasPrefix(urlTrimmed, "md:") { + // Markdown page mode: URL = "md:" + slug := strings.TrimPrefix(urlTrimmed, "md:") + if slug == "" { + response.BadRequest(c, "Custom menu item markdown slug cannot be empty (use md:slug format)") + return + } + } else { + if urlTrimmed == "" { + response.BadRequest(c, "Custom menu item URL is required (use md:slug for markdown pages)") + return + } + if len(item.URL) > maxMenuItemURLLen { + response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(urlTrimmed); err != nil { + response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL or md:") + return + } } if item.Visibility != "user" && item.Visibility != "admin" { response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") @@ -1148,6 +1268,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { FrontendURL: req.FrontendURL, InvitationCodeEnabled: req.InvitationCodeEnabled, TotpEnabled: req.TotpEnabled, + LoginAgreementEnabled: req.LoginAgreementEnabled, + LoginAgreementMode: loginAgreementMode, + LoginAgreementUpdatedAt: loginAgreementUpdatedAt, + LoginAgreementDocuments: loginAgreementDocuments, SMTPHost: req.SMTPHost, SMTPPort: req.SMTPPort, SMTPUsername: req.SMTPUsername, @@ -1200,6 +1324,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath, OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath, OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath, + GitHubOAuthEnabled: req.GitHubOAuthEnabled, + GitHubOAuthClientID: req.GitHubOAuthClientID, + GitHubOAuthClientSecret: req.GitHubOAuthClientSecret, + GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL, + GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL, + GoogleOAuthEnabled: req.GoogleOAuthEnabled, + GoogleOAuthClientID: req.GoogleOAuthClientID, + GoogleOAuthClientSecret: req.GoogleOAuthClientSecret, + GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL, + GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL, SiteName: req.SiteName, SiteLogo: req.SiteLogo, SiteSubtitle: req.SiteSubtitle, @@ -1365,6 +1499,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AffiliateEnabled }(), + RiskControlEnabled: func() bool { + if req.RiskControlEnabled != nil { + return *req.RiskControlEnabled + } + return previousSettings.RiskControlEnabled + }(), } authSourceDefaults := &service.AuthSourceDefaultSettings{ @@ -1396,6 +1536,20 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), }, + GitHub: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultGitHubConcurrency, previousAuthSourceDefaults.GitHub.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind), + }, + Google: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultGoogleConcurrency, previousAuthSourceDefaults.Google.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind), + }, ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), } if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil { @@ -1486,6 +1640,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, TotpEnabled: updatedSettings.TotpEnabled, TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), + LoginAgreementEnabled: updatedSettings.LoginAgreementEnabled, + LoginAgreementMode: updatedSettings.LoginAgreementMode, + LoginAgreementUpdatedAt: updatedSettings.LoginAgreementUpdatedAt, + LoginAgreementDocuments: loginAgreementDocumentsToDTO(updatedSettings.LoginAgreementDocuments), SMTPHost: updatedSettings.SMTPHost, SMTPPort: updatedSettings.SMTPPort, SMTPUsername: updatedSettings.SMTPUsername, @@ -1538,6 +1696,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath, OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath, OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath, + GitHubOAuthEnabled: updatedSettings.GitHubOAuthEnabled, + GitHubOAuthClientID: updatedSettings.GitHubOAuthClientID, + GitHubOAuthClientSecretConfigured: updatedSettings.GitHubOAuthClientSecretConfigured, + GitHubOAuthRedirectURL: updatedSettings.GitHubOAuthRedirectURL, + GitHubOAuthFrontendRedirectURL: updatedSettings.GitHubOAuthFrontendRedirectURL, + GoogleOAuthEnabled: updatedSettings.GoogleOAuthEnabled, + GoogleOAuthClientID: updatedSettings.GoogleOAuthClientID, + GoogleOAuthClientSecretConfigured: updatedSettings.GoogleOAuthClientSecretConfigured, + GoogleOAuthRedirectURL: updatedSettings.GoogleOAuthRedirectURL, + GoogleOAuthFrontendRedirectURL: updatedSettings.GoogleOAuthFrontendRedirectURL, SiteName: updatedSettings.SiteName, SiteLogo: updatedSettings.SiteLogo, SiteSubtitle: updatedSettings.SiteSubtitle, @@ -1616,6 +1784,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, AffiliateEnabled: updatedSettings.AffiliateEnabled, + + RiskControlEnabled: updatedSettings.RiskControlEnabled, } if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { slog.Error("openai_fast_policy_settings_get_failed", "error", err) @@ -1685,6 +1855,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.TotpEnabled != after.TotpEnabled { changed = append(changed, "totp_enabled") } + if before.LoginAgreementEnabled != after.LoginAgreementEnabled { + changed = append(changed, "login_agreement_enabled") + } + if before.LoginAgreementMode != after.LoginAgreementMode { + changed = append(changed, "login_agreement_mode") + } + if before.LoginAgreementUpdatedAt != after.LoginAgreementUpdatedAt { + changed = append(changed, "login_agreement_updated_at") + } + if !equalLoginAgreementDocuments(before.LoginAgreementDocuments, after.LoginAgreementDocuments) { + changed = append(changed, "login_agreement_documents") + } if before.SMTPHost != after.SMTPHost { changed = append(changed, "smtp_host") } @@ -2004,6 +2186,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AffiliateEnabled != after.AffiliateEnabled { changed = append(changed, "affiliate_enabled") } + if before.RiskControlEnabled != after.RiskControlEnabled { + changed = append(changed, "risk_control_enabled") + } changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) return changed } @@ -2027,6 +2212,8 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo}, {name: "oidc", before: before.OIDC, after: after.OIDC}, {name: "wechat", before: before.WeChat, after: after.WeChat}, + {name: "github", before: before.GitHub, after: after.GitHub}, + {name: "google", before: before.Google, after: after.Google}, } for _, field := range fields { if field.before.Balance != field.after.Balance { @@ -2141,6 +2328,16 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind + data["auth_source_default_github_balance"] = authSourceDefaults.GitHub.Balance + data["auth_source_default_github_concurrency"] = authSourceDefaults.GitHub.Concurrency + data["auth_source_default_github_subscriptions"] = authSourceDefaults.GitHub.Subscriptions + data["auth_source_default_github_grant_on_signup"] = authSourceDefaults.GitHub.GrantOnSignup + data["auth_source_default_github_grant_on_first_bind"] = authSourceDefaults.GitHub.GrantOnFirstBind + data["auth_source_default_google_balance"] = authSourceDefaults.Google.Balance + data["auth_source_default_google_concurrency"] = authSourceDefaults.Google.Concurrency + data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions + data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup + data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup return data @@ -2170,6 +2367,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { return true } +func equalLoginAgreementDocuments(a, b []service.LoginAgreementDocument) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].ID != b[i].ID || a[i].Title != b[i].Title || a[i].ContentMD != b[i].ContentMD { + return false + } + } + return true +} + func equalIntSlice(a, b []int) bool { if len(a) != len(b) { return false diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index a297c56c..db35472e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -477,3 +477,63 @@ func (h *UserHandler) GetUserRPMStatus(c *gin.Context) { response.Success(c, status) } + +// BatchUpdateConcurrency 批量修改用户并发数 +// POST /api/v1/admin/users/batch-concurrency +type BatchUpdateConcurrencyRequest struct { + UserIDs []int64 `json:"user_ids"` + All bool `json:"all"` + Concurrency int `json:"concurrency"` + Mode string `json:"mode" binding:"required,oneof=set add"` +} + +func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) { + var req BatchUpdateConcurrencyRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.All && len(req.UserIDs) == 0 { + response.BadRequest(c, "user_ids is required unless all=true") + return + } + if len(req.UserIDs) > 500 { + response.BadRequest(c, "user_ids cannot exceed 500") + return + } + + var userIDs []int64 + if req.All { + // Fetch all user IDs via pagination + page := 1 + const pageSize = 500 + for { + users, _, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, service.UserListFilters{}, "id", "asc") + if err != nil { + response.ErrorFrom(c, err) + return + } + for _, u := range users { + userIDs = append(userIDs, u.ID) + } + if len(users) < pageSize { + break + } + page++ + } + } else { + userIDs = req.UserIDs + } + + if len(userIDs) == 0 { + response.Success(c, gin.H{"affected": 0}) + return + } + + affected, err := h.adminService.BatchUpdateConcurrency(c.Request.Context(), userIDs, req.Concurrency, req.Mode) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"affected": affected}) +} diff --git a/backend/internal/handler/auth_email_oauth.go b/backend/internal/handler/auth_email_oauth.go new file mode 100644 index 00000000..d43acef6 --- /dev/null +++ b/backend/internal/handler/auth_email_oauth.go @@ -0,0 +1,621 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + emailOAuthCookiePath = "/api/v1/auth/oauth" + emailOAuthStateCookieName = "email_oauth_state" + emailOAuthRedirectCookie = "email_oauth_redirect" + emailOAuthProviderCookie = "email_oauth_provider" + emailOAuthAffiliateCookie = "email_oauth_affiliate" + emailOAuthCookieMaxAgeSec = 10 * 60 + emailOAuthDefaultRedirect = "/dashboard" +) + +type emailOAuthTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` +} + +type emailOAuthProfile struct { + Subject string + Email string + EmailVerified bool + Username string + DisplayName string + AvatarURL string + Metadata map[string]any +} + +func (h *AuthHandler) GitHubOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "github") } +func (h *AuthHandler) GoogleOAuthStart(c *gin.Context) { h.emailOAuthStart(c, "google") } + +func (h *AuthHandler) GitHubOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "github") } +func (h *AuthHandler) GoogleOAuthCallback(c *gin.Context) { h.emailOAuthCallback(c, "google") } +func (h *AuthHandler) CompleteGitHubOAuthRegistration(c *gin.Context) { + h.completeEmailOAuthRegistration(c, "github") +} +func (h *AuthHandler) CompleteGoogleOAuthRegistration(c *gin.Context) { + h.completeEmailOAuthRegistration(c, "google") +} + +func (h *AuthHandler) emailOAuthStart(c *gin.Context, provider string) { + cfg, err := h.getEmailOAuthConfig(c.Request.Context(), provider) + if err != nil { + response.ErrorFrom(c, err) + return + } + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = emailOAuthDefaultRedirect + } + + secureCookie := isRequestHTTPS(c) + emailOAuthSetCookie(c, emailOAuthStateCookieName, encodeCookieValue(state), secureCookie) + emailOAuthSetCookie(c, emailOAuthRedirectCookie, encodeCookieValue(redirectTo), secureCookie) + emailOAuthSetCookie(c, emailOAuthProviderCookie, encodeCookieValue(provider), secureCookie) + if affCode := strings.TrimSpace(firstNonEmpty(c.Query("aff_code"), c.Query("aff"))); affCode != "" { + emailOAuthSetCookie(c, emailOAuthAffiliateCookie, encodeCookieValue(affCode), secureCookie) + } else { + emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie) + } + + authURL, err := buildEmailOAuthAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + c.Redirect(http.StatusFound, authURL) +} + +func (h *AuthHandler) emailOAuthCallback(c *gin.Context, provider string) { + cfg, cfgErr := h.getEmailOAuthConfig(c.Request.Context(), provider) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = "/auth/oauth/callback" + } + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + emailOAuthClearCookie(c, emailOAuthStateCookieName, secureCookie) + emailOAuthClearCookie(c, emailOAuthRedirectCookie, secureCookie) + emailOAuthClearCookie(c, emailOAuthProviderCookie, secureCookie) + emailOAuthClearCookie(c, emailOAuthAffiliateCookie, secureCookie) + }() + expectedState, err := readCookieDecoded(c, emailOAuthStateCookieName) + if err != nil || expectedState == "" || expectedState != state { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + expectedProvider, _ := readCookieDecoded(c, emailOAuthProviderCookie) + if !strings.EqualFold(strings.TrimSpace(expectedProvider), provider) { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth provider", "") + return + } + redirectTo, _ := readCookieDecoded(c, emailOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = emailOAuthDefaultRedirect + } + + tokenResp, err := exchangeEmailOAuthCode(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(err.Error())) + return + } + profile, err := fetchEmailOAuthProfile(c.Request.Context(), provider, cfg, tokenResp) + if err != nil { + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch verified email", singleLine(err.Error())) + return + } + h.emailOAuthCallbackWithProfile(c, provider, cfg, frontendCallback, redirectTo, profile) +} + +func (h *AuthHandler) emailOAuthCallbackWithProfile( + c *gin.Context, + provider string, + cfg config.EmailOAuthProviderConfig, + frontendCallback string, + redirectTo string, + profile *emailOAuthProfile, +) { + input := service.EmailOAuthIdentityInput{ + ProviderType: provider, + ProviderKey: provider, + ProviderSubject: profile.Subject, + Email: profile.Email, + EmailVerified: profile.EmailVerified, + Username: profile.Username, + DisplayName: profile.DisplayName, + AvatarURL: profile.AvatarURL, + UpstreamMetadata: profile.Metadata, + } + affiliateCode := h.emailOAuthAffiliateCode(c) + if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil { + redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "") + return + } else if shouldCreate { + if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil { + redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode) + if err != nil { + if errors.Is(err, service.ErrOAuthInvitationRequired) { + if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil { + redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "") + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", tokenPair.AccessToken) + fragment.Set("refresh_token", tokenPair.RefreshToken) + fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) { + client := h.entClient() + if client == nil { + return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{ + ProviderType: strings.TrimSpace(input.ProviderType), + ProviderKey: strings.TrimSpace(input.ProviderKey), + ProviderSubject: strings.TrimSpace(input.ProviderSubject), + }) + if err != nil { + return false, err + } + email := strings.TrimSpace(strings.ToLower(input.Email)) + if identityUser != nil { + if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) { + return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email") + } + return false, nil + } + if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil { + if errors.Is(err, service.ErrUserNotFound) { + return true, nil + } + return false, err + } + return false, nil +} + +func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string { + if c == nil { + return "" + } + if code, err := readCookieDecoded(c, emailOAuthAffiliateCookie); err == nil { + return strings.TrimSpace(code) + } + return "" +} + +func (h *AuthHandler) createEmailOAuthRegistrationPendingSession( + c *gin.Context, + provider string, + frontendCallback string, + redirectTo string, + profile *emailOAuthProfile, +) error { + if h == nil || profile == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) + } + setOAuthPendingBrowserCookie(c, browserSessionKey, isRequestHTTPS(c)) + + email := strings.TrimSpace(strings.ToLower(profile.Email)) + username := strings.TrimSpace(profile.Username) + affiliateCode := h.emailOAuthAffiliateCode(c) + upstreamClaims := map[string]any{ + "email": email, + "email_verified": profile.EmailVerified, + "username": username, + "provider": provider, + "provider_key": provider, + "provider_subject": strings.TrimSpace(profile.Subject), + } + if strings.TrimSpace(profile.DisplayName) != "" { + upstreamClaims["suggested_display_name"] = strings.TrimSpace(profile.DisplayName) + } + if strings.TrimSpace(profile.AvatarURL) != "" { + upstreamClaims["suggested_avatar_url"] = strings.TrimSpace(profile.AvatarURL) + } + if affiliateCode != "" { + upstreamClaims["aff_code"] = affiliateCode + } + for key, value := range profile.Metadata { + if _, exists := upstreamClaims[key]; !exists { + upstreamClaims[key] = value + } + } + + invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) + pendingError := "registration_completion_required" + choiceReason := "registration_completion_required" + if invitationRequired { + pendingError = "invitation_required" + choiceReason = "invitation_required" + } + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "error": pendingError, + "choice_reason": choiceReason, + "adoption_required": false, + "create_account_allowed": true, + "existing_account_bindable": false, + "force_email_on_signup": true, + "invitation_required": invitationRequired, + "email": email, + "resolved_email": email, + "provider": provider, + "redirect": redirectTo, + } + if strings.TrimSpace(frontendCallback) != "" { + completionResponse["frontend_callback"] = strings.TrimSpace(frontendCallback) + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: service.PendingAuthIdentityKey{ProviderType: provider, ProviderKey: provider, ProviderSubject: strings.TrimSpace(profile.Subject)}, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +type completeEmailOAuthRequest struct { + Password string `json:"password" binding:"required,min=6"` + InvitationCode string `json:"invitation_code,omitempty"` + AffCode string `json:"aff_code,omitempty"` +} + +func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider string) { + var req completeEmailOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + affiliateCode := strings.TrimSpace(req.AffCode) + if affiliateCode == "" { + affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code") + } + + tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount( + c.Request.Context(), + strings.TrimSpace(session.ResolvedEmail), + req.Password, + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + tx, err := client.Tx(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err)) + return + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(c.Request.Context(), tx) + sessionForBinding := *session + sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims) + if strings.TrimSpace(req.InvitationCode) != "" { + sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode) + } + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{}) + if err != nil { + _ = tx.Rollback() + _ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode)) + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil { + _ = tx.Rollback() + _ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode)) + respondPendingOAuthBindingApplyError(c, err) + return + } + if err := h.authService.FinalizeOAuthEmailAccount( + txCtx, + user, + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + affiliateCode, + ); err != nil { + _ = tx.Rollback() + _ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode)) + response.ErrorFrom(c, err) + return + } + if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil { + _ = tx.Rollback() + _ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode)) + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := tx.Commit(); err != nil { + _ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode)) + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err)) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +func (h *AuthHandler) getEmailOAuthConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) { + if h != nil && h.settingSvc != nil { + return h.settingSvc.GetEmailOAuthProviderConfig(ctx, provider) + } + return config.EmailOAuthProviderConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") +} + +func buildEmailOAuthAuthorizeURL(cfg config.EmailOAuthProviderConfig, state string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", cfg.RedirectURL) + q.Set("state", state) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +func exchangeEmailOAuthCode(ctx context.Context, cfg config.EmailOAuthProviderConfig, code string) (*emailOAuthTokenResponse, error) { + resp, err := req.C(). + R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetFormData(map[string]string{ + "grant_type": "authorization_code", + "client_id": cfg.ClientID, + "client_secret": cfg.ClientSecret, + "code": code, + "redirect_uri": cfg.RedirectURL, + }). + Post(cfg.TokenURL) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("token endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024)) + } + var tokenResp emailOAuthTokenResponse + if err := json.Unmarshal(resp.Bytes(), &tokenResp); err != nil { + return nil, err + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, errors.New("missing access_token") + } + return &tokenResp, nil +} + +func fetchEmailOAuthProfile(ctx context.Context, provider string, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse) (*emailOAuthProfile, error) { + resp, err := req.C(). + R(). + SetContext(ctx). + SetBearerAuthToken(token.AccessToken). + SetHeader("Accept", "application/json"). + Get(cfg.UserInfoURL) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("userinfo endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024)) + } + switch strings.ToLower(strings.TrimSpace(provider)) { + case "github": + return parseGitHubOAuthProfile(ctx, cfg, token, resp.String()) + case "google": + return parseGoogleOAuthProfile(resp.String()) + default: + return nil, errors.New("unsupported oauth provider") + } +} + +func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderConfig, token *emailOAuthTokenResponse, body string) (*emailOAuthProfile, error) { + subject := strings.TrimSpace(gjson.Get(body, "id").String()) + if subject == "" { + return nil, errors.New("github user id is missing") + } + email := "" + emailsURL := strings.TrimSpace(cfg.EmailsURL) + if emailsURL == "" { + return nil, errors.New("github verified email is missing") + } + verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken) + if err != nil { + return nil, err + } + email = verifiedEmail + if email == "" { + return nil, errors.New("github verified email is missing") + } + login := strings.TrimSpace(gjson.Get(body, "login").String()) + name := strings.TrimSpace(gjson.Get(body, "name").String()) + return &emailOAuthProfile{ + Subject: subject, + Email: email, + EmailVerified: true, + Username: firstNonEmpty(login, name, "github_"+subject), + DisplayName: firstNonEmpty(name, login), + AvatarURL: strings.TrimSpace(gjson.Get(body, "avatar_url").String()), + Metadata: map[string]any{ + "login": login, + }, + }, nil +} + +func fetchGitHubPrimaryVerifiedEmail(ctx context.Context, emailsURL string, accessToken string) (string, error) { + resp, err := req.C(). + R(). + SetContext(ctx). + SetBearerAuthToken(accessToken). + SetHeader("Accept", "application/json"). + Get(emailsURL) + if err != nil { + return "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("github emails endpoint status %d: %s", resp.StatusCode, truncateLogValue(resp.String(), 1024)) + } + items := gjson.Parse(resp.String()).Array() + for _, item := range items { + if item.Get("primary").Bool() && item.Get("verified").Bool() { + if email := strings.TrimSpace(item.Get("email").String()); email != "" { + return email, nil + } + } + } + for _, item := range items { + if item.Get("verified").Bool() { + if email := strings.TrimSpace(item.Get("email").String()); email != "" { + return email, nil + } + } + } + return "", errors.New("github verified email is missing") +} + +func parseGoogleOAuthProfile(body string) (*emailOAuthProfile, error) { + subject := strings.TrimSpace(gjson.Get(body, "sub").String()) + email := strings.TrimSpace(gjson.Get(body, "email").String()) + verified := gjson.Get(body, "email_verified").Bool() + if subject == "" { + return nil, errors.New("google subject is missing") + } + if email == "" || !verified { + return nil, errors.New("google verified email is missing") + } + name := strings.TrimSpace(gjson.Get(body, "name").String()) + return &emailOAuthProfile{ + Subject: subject, + Email: email, + EmailVerified: true, + Username: firstNonEmpty(strings.TrimSpace(gjson.Get(body, "given_name").String()), name, email), + DisplayName: name, + AvatarURL: strings.TrimSpace(gjson.Get(body, "picture").String()), + Metadata: map[string]any{ + "email_verified": true, + }, + }, nil +} + +func emailOAuthSetCookie(c *gin.Context, name, value string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: emailOAuthCookiePath, + MaxAge: emailOAuthCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func emailOAuthClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: emailOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_email_oauth_test.go b/backend/internal/handler/auth_email_oauth_test.go new file mode 100644 index 00000000..ecf71c5a --- /dev/null +++ b/backend/internal/handler/auth_email_oauth_test.go @@ -0,0 +1,414 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, true) + ctx := context.Background() + + state := "github-oauth-state" + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil) + req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)}) + req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")}) + req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")}) + c.Request = req + + profile := &emailOAuthProfile{ + Subject: "github-123", + Email: "fresh@example.com", + EmailVerified: true, + Username: "fresh", + DisplayName: "Fresh User", + AvatarURL: "https://cdn.example/fresh.png", + Metadata: map[string]any{ + "login": "fresh", + }, + } + handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{ + Enabled: true, + ClientID: "github-client", + ClientSecret: "github-secret", + RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback", + FrontendRedirectURL: "/auth/oauth/callback", + }, "/auth/oauth/callback", "/dashboard", profile) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "/auth/oauth/callback") + require.NotContains(t, location, "access_token=") + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + session, err := client.PendingAuthSession.Query().Only(ctx) + require.NoError(t, err) + require.Equal(t, "github", session.ProviderType) + require.Equal(t, "github", session.ProviderKey) + require.Equal(t, "github-123", session.ProviderSubject) + require.Equal(t, "fresh@example.com", session.ResolvedEmail) + require.Equal(t, "/dashboard", session.RedirectTo) + require.Nil(t, session.TargetUserID) + + completion, ok := readCompletionResponse(session.LocalFlowState) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "invitation_required", completion["error"]) + require.Equal(t, true, completion["invitation_required"]) + require.Equal(t, "fresh@example.com", completion["email"]) + require.Equal(t, "fresh@example.com", completion["resolved_email"]) + require.Equal(t, true, completion["create_account_allowed"]) + + require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName)) +} + +func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, true) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("existing@example.com"). + SetUsername("existing"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil) + + handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{ + Enabled: true, + ClientID: "google-client", + ClientSecret: "google-secret", + RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback", + FrontendRedirectURL: "/auth/oauth/callback", + }, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{ + Subject: "google-123", + Email: "existing@example.com", + EmailVerified: true, + Username: "existing", + }) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "access_token=") + require.Contains(t, location, "redirect=%252Fdashboard") + + sessionCount, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, sessionCount) + + identityCount, err := client.AuthIdentity.Query().Where( + authidentity.ProviderTypeEQ("google"), + authidentity.ProviderSubjectEQ("google-123"), + ).Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + _ = user +} + +func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) { + affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001}) + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyAffiliateEnabled: "true", + }, + affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService { + return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil) + }, + }) + ctx := context.Background() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil) + req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")}) + c.Request = req + + handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{ + Enabled: true, + ClientID: "github-client", + ClientSecret: "github-secret", + RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback", + FrontendRedirectURL: "/auth/oauth/callback", + }, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{ + Subject: "github-aff-user", + Email: "aff-user@example.com", + EmailVerified: true, + Username: "aff-user", + }) + + require.Equal(t, http.StatusFound, recorder.Code) + require.NotContains(t, recorder.Header().Get("Location"), "access_token=") + userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + require.Empty(t, affiliateRepo.ensureUserIDs) + require.Empty(t, affiliateRepo.bindCalls) + + session, err := client.PendingAuthSession.Query().Only(ctx) + require.NoError(t, err) + require.Equal(t, "aff-user@example.com", session.ResolvedEmail) + require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")) + + completion, ok := readCompletionResponse(session.LocalFlowState) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "registration_completion_required", completion["error"]) + require.Equal(t, false, completion["invitation_required"]) + require.Equal(t, true, completion["create_account_allowed"]) + require.Equal(t, true, completion["force_email_on_signup"]) + require.Equal(t, "aff-user@example.com", completion["resolved_email"]) +} + +func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) { + affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002}) + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + invitationEnabled: true, + settingValues: map[string]string{ + service.SettingKeyAffiliateEnabled: "true", + }, + affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService { + return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil) + }, + }) + ctx := context.Background() + invitation, err := client.RedeemCode.Create(). + SetCode("INVITE456"). + SetType(service.RedeemTypeInvitation). + SetStatus(service.StatusUnused). + SetValue(0). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("email-oauth-aff-session-token"). + SetIntent(oauthIntentLogin). + SetProviderType("google"). + SetProviderKey("google"). + SetProviderSubject("google-aff-user"). + SetResolvedEmail("pending-aff@example.com"). + SetRedirectTo("/dashboard"). + SetBrowserSessionKey("browser-aff-key"). + SetUpstreamIdentityClaims(map[string]any{ + "email": "pending-aff@example.com", + "email_verified": true, + "username": "pending-aff", + "provider": "google", + "provider_key": "google", + "provider_subject": "google-aff-user", + "aff_code": "AFF456", + }). + SetLocalFlowState(map[string]any{ + "step": oauthPendingChoiceStep, + "error": "invitation_required", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")}) + c.Request = req + + handler.completeEmailOAuthRegistration(c, "google") + + require.Equal(t, http.StatusOK, recorder.Code) + user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx) + require.NoError(t, err) + require.NotEmpty(t, user.PasswordHash) + require.NotEqual(t, "secret-123", user.PasswordHash) + tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, tamperedCount) + require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls) + storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedInvitation.UsedBy) + require.Equal(t, user.ID, *storedInvitation.UsedBy) +} + +func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("email-oauth-password-session-token"). + SetIntent(oauthIntentLogin). + SetProviderType("github"). + SetProviderKey("github"). + SetProviderSubject("github-password-user"). + SetResolvedEmail("password-required@example.com"). + SetRedirectTo("/dashboard"). + SetBrowserSessionKey("browser-password-key"). + SetUpstreamIdentityClaims(map[string]any{ + "email": "password-required@example.com", + "email_verified": true, + "username": "password-required", + "provider": "github", + "provider_key": "github", + "provider_subject": "github-password-user", + }). + SetLocalFlowState(map[string]any{ + "step": oauthPendingChoiceStep, + "error": "registration_completion_required", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")}) + c.Request = req + + handler.completeEmailOAuthRegistration(c, "github") + + require.Equal(t, http.StatusBadRequest, recorder.Code) + userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) +} + +func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) { + emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "missing scope", http.StatusForbidden) + })) + t.Cleanup(emailServer.Close) + + profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{ + EmailsURL: emailServer.URL, + }, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`) + + require.Error(t, err) + require.Nil(t, profile) + require.Contains(t, err.Error(), "github emails endpoint status 403") +} + +type oauthEmailAffiliateBindCall struct { + userID int64 + inviterID int64 +} + +type oauthEmailAffiliateRepoStub struct { + codeOwners map[string]int64 + ensureUserIDs []int64 + bindCalls []oauthEmailAffiliateBindCall +} + +func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub { + return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners} +} + +func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) { + r.ensureUserIDs = append(r.ensureUserIDs, userID) + return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil +} + +func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) { + userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))] + if !ok { + return nil, service.ErrAffiliateProfileNotFound + } + return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil +} + +func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) { + r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID}) + return true, nil +} + +func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) { + panic("unexpected AccrueQuota call") +} + +func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) { + panic("unexpected GetAccruedRebateFromInvitee call") +} + +func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) { + panic("unexpected ThawFrozenQuota call") +} + +func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) { + panic("unexpected TransferQuotaToBalance call") +} + +func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) { + panic("unexpected ListInvitees call") +} + +func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error { + panic("unexpected UpdateUserAffCode call") +} + +func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) { + panic("unexpected ResetUserAffCode call") +} + +func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error { + panic("unexpected SetUserRebateRate call") +} + +func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error { + panic("unexpected BatchSetUserRebateRate call") +} + +func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) { + panic("unexpected ListUsersWithCustomSettings call") +} + +func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) { + panic("unexpected ListAffiliateInviteRecords call") +} + +func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) { + panic("unexpected ListAffiliateRebateRecords call") +} + +func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) { + panic("unexpected ListAffiliateTransferRecords call") +} + +func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) { + panic("unexpected GetAffiliateUserOverview call") +} + +func findSetCookieValue(cookies []*http.Cookie, name string) string { + for _, cookie := range cookies { + if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 { + return cookie.Value + } + } + return "" +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index ffe9ff5f..584e5751 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2121,6 +2121,8 @@ type oauthPendingFlowTestHandlerOptions struct { emailCache service.EmailCache settingValues map[string]string defaultSubAssigner service.DefaultSubscriptionAssigner + affiliateService *service.AffiliateService + affiliateFactory func(*dbent.Client, *service.SettingService) *service.AffiliateService totpCache service.TotpCache totpEncryptor service.SecretEncryptor userRepoOptions oauthPendingFlowUserRepoOptions @@ -2160,6 +2162,21 @@ CREATE TABLE IF NOT EXISTS user_avatars ( updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP )`) require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_affiliates ( + user_id INTEGER PRIMARY KEY, + aff_code TEXT NOT NULL UNIQUE, + aff_code_custom BOOLEAN NOT NULL DEFAULT false, + aff_rebate_rate_percent REAL NULL, + inviter_id INTEGER NULL, + aff_count INTEGER NOT NULL DEFAULT 0, + aff_quota REAL NOT NULL DEFAULT 0, + aff_frozen_quota REAL NOT NULL DEFAULT 0, + aff_history_quota REAL NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +)`) + require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) @@ -2177,14 +2194,19 @@ CREATE TABLE IF NOT EXISTS user_avatars ( }, } settingValues := map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled), - service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + service.SettingKeyRegistrationEmailSuffixWhitelist: "[]", } for key, value := range options.settingValues { settingValues[key] = value } settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) + affiliateService := options.affiliateService + if affiliateService == nil && options.affiliateFactory != nil { + affiliateService = options.affiliateFactory(client, settingSvc) + } userRepo := &oauthPendingFlowUserRepo{ client: client, options: options.userRepoOptions, @@ -2210,7 +2232,7 @@ CREATE TABLE IF NOT EXISTS user_avatars ( nil, nil, options.defaultSubAssigner, - nil, + affiliateService, ) userSvc := service.NewUserService(userRepo, nil, nil, nil) var totpSvc *service.TotpService @@ -2798,6 +2820,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int panic("unexpected UpdateConcurrency call") } +func (r *oauthPendingFlowUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + panic("unexpected BatchSetConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + panic("unexpected BatchAddConcurrency call") +} + func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { return map[int64]*time.Time{}, nil } diff --git a/backend/internal/handler/content_moderation_helper.go b/backend/internal/handler/content_moderation_helper.go new file mode 100644 index 00000000..af6dbd8e --- /dev/null +++ b/backend/internal/handler/content_moderation_helper.go @@ -0,0 +1,130 @@ +package handler + +import ( + "context" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if h == nil || h.contentModerationService == nil { + return nil + } + return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body) +} + +func contentModerationStatus(decision *service.ContentModerationDecision) int { + if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 { + return http.StatusForbidden + } + return decision.StatusCode +} + +func contentModerationErrorCode(decision *service.ContentModerationDecision) string { + return "content_policy_violation" +} + +func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if h == nil || h.contentModerationService == nil { + return nil + } + return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body) +} + +func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if svc == nil || c == nil || c.Request == nil { + return nil + } + input := buildContentModerationInput(c, apiKey, subject, protocol, model, body) + if reqLog != nil { + reqLog.Info("content_moderation.gateway_check_start", + zap.String("request_id", input.RequestID), + zap.Int64("user_id", input.UserID), + zap.Int64("api_key_id", input.APIKeyID), + zap.String("api_key_name", input.APIKeyName), + zap.Int64p("group_id", input.GroupID), + zap.String("group_name", input.GroupName), + zap.String("endpoint", input.Endpoint), + zap.String("provider", input.Provider), + zap.String("protocol", input.Protocol), + zap.String("model", input.Model), + zap.Int("body_bytes", len(body)), + ) + } + decision, err := svc.Check(c.Request.Context(), input) + if err != nil { + if reqLog != nil { + reqLog.Warn("content_moderation.check_failed", zap.Error(err)) + } + return nil + } + if reqLog != nil && decision != nil { + reqLog.Info("content_moderation.gateway_check_done", + zap.String("request_id", input.RequestID), + zap.Bool("allowed", decision.Allowed), + zap.Bool("blocked", decision.Blocked), + zap.Bool("flagged", decision.Flagged), + zap.String("action", decision.Action), + zap.Int("status_code", decision.StatusCode), + zap.String("highest_category", decision.HighestCategory), + zap.Float64("highest_score", decision.HighestScore), + ) + } + return decision +} + +func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput { + input := service.ContentModerationCheckInput{ + RequestID: contentModerationRequestID(c.Request.Context()), + UserID: subject.UserID, + Endpoint: GetInboundEndpoint(c), + Provider: contentModerationProvider(apiKey), + Model: strings.TrimSpace(model), + Protocol: protocol, + Body: body, + } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { + input.Provider = strings.TrimSpace(forcedPlatform) + } + if apiKey != nil { + input.APIKeyID = apiKey.ID + input.APIKeyName = apiKey.Name + if apiKey.User != nil { + input.UserEmail = apiKey.User.Email + } + if apiKey.GroupID != nil { + groupID := *apiKey.GroupID + input.GroupID = &groupID + } + if apiKey.Group != nil { + input.GroupName = apiKey.Group.Name + } + } + if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil { + input.Endpoint = c.Request.URL.Path + } + return input +} + +func contentModerationProvider(apiKey *service.APIKey) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.Platform) +} + +func contentModerationRequestID(ctx context.Context) string { + if ctx == nil { + return "" + } + if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok { + return strings.TrimSpace(requestID) + } + return "" +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 0bc834fe..2d4cefa1 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -11,6 +11,7 @@ type CustomMenuItem struct { Label string `json:"label"` IconSVG string `json:"icon_svg"` URL string `json:"url"` + PageSlug string `json:"page_slug,omitempty"` Visibility string `json:"visibility"` // "user" or "admin" SortOrder int `json:"sort_order"` } @@ -24,15 +25,19 @@ type CustomEndpoint struct { // SystemSettings represents the admin settings API response payload. type SystemSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - FrontendURL string `json:"frontend_url"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + FrontendURL string `json:"frontend_url"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 + LoginAgreementEnabled bool `json:"login_agreement_enabled"` + LoginAgreementMode string `json:"login_agreement_mode"` + LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"` + LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"` SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` @@ -91,6 +96,17 @@ type SystemSettings struct { OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"` OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"` + GitHubOAuthEnabled bool `json:"github_oauth_enabled"` + GitHubOAuthClientID string `json:"github_oauth_client_id"` + GitHubOAuthClientSecretConfigured bool `json:"github_oauth_client_secret_configured"` + GitHubOAuthRedirectURL string `json:"github_oauth_redirect_url"` + GitHubOAuthFrontendRedirectURL string `json:"github_oauth_frontend_redirect_url"` + GoogleOAuthEnabled bool `json:"google_oauth_enabled"` + GoogleOAuthClientID string `json:"google_oauth_client_id"` + GoogleOAuthClientSecretConfigured bool `json:"google_oauth_client_secret_configured"` + GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url"` + GoogleOAuthFrontendRedirectURL string `json:"google_oauth_frontend_redirect_url"` + SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` SiteSubtitle string `json:"site_subtitle"` @@ -197,6 +213,9 @@ type SystemSettings struct { // Available Channels feature switch (user-facing aggregate view) AvailableChannelsEnabled bool `json:"available_channels_enabled"` + // 风控中心功能开关 + RiskControlEnabled bool `json:"risk_control_enabled"` + // Affiliate (邀请返利) feature switch AffiliateEnabled bool `json:"affiliate_enabled"` @@ -210,45 +229,52 @@ type DefaultSubscriptionSetting struct { } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` - RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - TableDefaultPageSize int `json:"table_default_page_size"` - TablePageSizeOptions []int `json:"table_page_size_options"` - CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` - CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` - WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` - WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` - WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` - OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` - OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` - SoraClientEnabled bool `json:"sora_client_enabled"` - BackendModeEnabled bool `json:"backend_mode_enabled"` - PaymentEnabled bool `json:"payment_enabled"` - Version string `json:"version"` - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + LoginAgreementEnabled bool `json:"login_agreement_enabled"` + LoginAgreementMode string `json:"login_agreement_mode"` + LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"` + LoginAgreementRevision string `json:"login_agreement_revision"` + LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` + OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` + OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` + GitHubOAuthEnabled bool `json:"github_oauth_enabled"` + GoogleOAuthEnabled bool `json:"google_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` + BackendModeEnabled bool `json:"backend_mode_enabled"` + PaymentEnabled bool `json:"payment_enabled"` + Version string `json:"version"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` @@ -256,6 +282,14 @@ type PublicSettings struct { AvailableChannelsEnabled bool `json:"available_channels_enabled"` AffiliateEnabled bool `json:"affiliate_enabled"` + + RiskControlEnabled bool `json:"risk_control_enabled"` +} + +type LoginAgreementDocument struct { + ID string `json:"id"` + Title string `json:"title"` + ContentMD string `json:"content_md"` } // OverloadCooldownSettings 529过载冷却配置 DTO diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index d12c2941..238bc892 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -46,6 +46,7 @@ type GatewayHandler struct { apiKeyService *service.APIKeyService usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService + contentModerationService *service.ContentModerationService concurrencyHelper *ConcurrencyHelper userMsgQueueHelper *UserMsgQueueHelper requestEventBus *service.RequestEventBus @@ -68,6 +69,7 @@ func NewGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + contentModerationService *service.ContentModerationService, userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, settingService *service.SettingService, @@ -103,6 +105,7 @@ func NewGatewayHandler( apiKeyService: apiKeyService, usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, + contentModerationService: contentModerationService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), userMsgQueueHelper: umqHelper, requestEventBus: requestEventBus, @@ -215,6 +218,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Track if we've started streaming (for error handling) streamStarted := false diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 4290e54b..c6b73190 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked { + h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Error passthrough binding if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index 683cf2b7..a97f572d 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked { + h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Error passthrough binding if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2a34e3f0..90ebe9ec 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked { + googleError(c, contentModerationStatus(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 906ab95f..1bfe9855 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -33,6 +33,7 @@ type AdminHandlers struct { Channel *admin.ChannelHandler ChannelMonitor *admin.ChannelMonitorHandler ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler + ContentModeration *admin.ContentModerationHandler Payment *admin.PaymentHandler Windsurf *admin.WindsurfHandler Affiliate *admin.AffiliateHandler diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 06ab9d52..de384710 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3997a0ee..6b07b7ba 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -27,15 +27,16 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService - billingCacheService *service.BillingCacheService - apiKeyService *service.APIKeyService - usageRecordWorkerPool *service.UsageRecordWorkerPool - errorPassthroughService *service.ErrorPassthroughService - concurrencyHelper *ConcurrencyHelper - imageLimiter *imageConcurrencyLimiter - maxAccountSwitches int - cfg *config.Config + gatewayService *service.OpenAIGatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool + errorPassthroughService *service.ErrorPassthroughService + contentModerationService *service.ContentModerationService + concurrencyHelper *ConcurrencyHelper + imageLimiter *imageConcurrencyLimiter + maxAccountSwitches int + cfg *config.Config } func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { @@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + contentModerationService *service.ContentModerationService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, - billingCacheService: billingCacheService, - apiKeyService: apiKeyService, - usageRecordWorkerPool: usageRecordWorkerPool, - errorPassthroughService: errorPassthroughService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - imageLimiter: &imageConcurrencyLimiter{}, - maxAccountSwitches: maxAccountSwitches, - cfg: cfg, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, + errorPassthroughService: errorPassthroughService, + contentModerationService: contentModerationService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + imageLimiter: &imageConcurrencyLimiter{}, + maxAccountSwitches: maxAccountSwitches, + cfg: cfg, } } @@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body) if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) { h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) @@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked { + h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) @@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message) + return + } + if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) { closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage()) return @@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { hooks := &service.OpenAIWSIngressHooks{ InitialRequestModel: reqModel, + BeforeRequest: func(turn int, payload []byte, originalModel string) error { + if turn == 1 { + return nil + } + if !gjson.ValidBytes(payload) { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + model := strings.TrimSpace(originalModel) + if model == "" { + model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + } + if model == "" { + model = reqModel + } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) + } + return nil + }, BeforeTurn: func(turn int) error { if turn == 1 { return nil @@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s _ = conn.CloseNow() } +func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { + if conn == nil || decision == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + message := strings.TrimSpace(decision.Message) + if message == "" { + message = "content moderation blocked this request" + } + payload, err := json.Marshal(gin.H{ + "event_id": "evt_content_moderation_blocked", + "type": "error", + "error": gin.H{ + "type": "invalid_request_error", + "code": contentModerationErrorCode(decision), + "message": message, + }, + }) + if err != nil { + payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`) + } + writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + _ = conn.Write(writeCtx, coderws.MessageText, payload) +} + func summarizeWSCloseErrorForLog(err error) (string, string) { if err == nil { return "-", "-" diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index c560350e..6bddbce9 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" coderws "github.com/coder/websocket" @@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") } +type contentModerationHandlerSettingRepo struct { + values map[string]string +} + +func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + if value, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: value}, nil + } + return nil, service.ErrSettingNotFound +} + +func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + if value, ok := r.values[key]; ok { + return value, nil + } + return "", service.ErrSettingNotFound +} + +func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error { + if r.values == nil { + r.values = map[string]string{} + } + r.values[key] = value + return nil +} + +func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := map[string]string{} + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + if r.values == nil { + r.values = map[string]string{} + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error { + delete(r.values, key) + return nil +} + +type contentModerationHandlerTestRepo struct { + logs []service.ContentModerationLog +} + +func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { + if log != nil { + r.logs = append(r.logs, *log) + } + return nil +} + +func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + return 0, nil +} + +func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { + return &service.ContentModerationCleanupResult{}, nil +} + +func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) { + gin.SetMode(gin.TestMode) + + moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + _, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`)) + })) + defer moderationServer.Close() + + cfg := &service.ContentModerationConfig{ + Enabled: true, + Mode: service.ContentModerationModePreBlock, + BaseURL: moderationServer.URL, + Model: "omni-moderation-latest", + APIKeys: []string{"sk-test"}, + SampleRate: 100, + AllGroups: true, + BlockMessage: "内容审计测试阻断", + } + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationHandlerTestRepo{} + settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{ + service.SettingKeyRiskControlEnabled: "true", + service.SettingKeyContentModerationConfig: string(rawCfg), + }} + moderationSvc := service.NewContentModerationService( + settingRepo, + repo, + nil, + nil, + nil, + nil, + nil, + ) + decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{ + UserID: 1, + Endpoint: "/v1/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: service.ContentModerationProtocolOpenAIResponses, + Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + repo.logs = nil + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + contentModerationService: moderationSvc, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second), + } + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{ + "type":"response.create", + "model":"gpt-5.5", + "input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}] + }`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, payload, readErr := clientConn.Read(readCtx) + cancelRead() + if readErr == nil { + require.Contains(t, string(payload), "content_policy_violation") + require.Contains(t, string(payload), "内容审计测试阻断") + } else { + var closeErr coderws.CloseError + require.ErrorAs(t, readErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, closeErr.Reason, "内容审计测试阻断") + } + require.Len(t, repo.logs, 1) + require.True(t, repo.logs[0].Flagged) + require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action) + require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt) +} + func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) { got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`, diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index eba701f1..08a6b6e8 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted) if !acquired { return diff --git a/backend/internal/handler/page_handler.go b/backend/internal/handler/page_handler.go new file mode 100644 index 00000000..7d4d5078 --- /dev/null +++ b/backend/internal/handler/page_handler.go @@ -0,0 +1,283 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/url" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +var validSlugPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) + +const maxPageFileSize = 1 << 20 // 1MB + +type PageHandler struct { + pagesDir string + settingService *service.SettingService +} + +func NewPageHandler(dataDir string, settingService *service.SettingService) *PageHandler { + pagesDir := filepath.Join(dataDir, "pages") + _ = os.MkdirAll(pagesDir, 0755) + return &PageHandler{pagesDir: pagesDir, settingService: settingService} +} + +// GetPageContent serves raw markdown content for a given slug. +// GET /api/v1/pages/:slug +func (h *PageHandler) GetPageContent(c *gin.Context) { + slug := c.Param("slug") + if !validSlugPattern.MatchString(slug) || len(slug) > 64 { + response.BadRequest(c, "Invalid page slug") + return + } + + // Visibility check: slug must be configured in custom_menu_items + // and the user must have permission based on visibility setting + if !h.checkSlugVisibility(c, slug) { + c.JSON(http.StatusNotFound, gin.H{"error": "page not found"}) + return + } + + filePath := filepath.Join(h.pagesDir, slug+".md") + cleaned := filepath.Clean(filePath) + if !strings.HasPrefix(cleaned, filepath.Clean(h.pagesDir)) { + response.BadRequest(c, "Invalid page slug") + return + } + + info, err := os.Stat(cleaned) + if err != nil || info.IsDir() { + c.JSON(http.StatusNotFound, gin.H{"error": "page not found"}) + return + } + if info.Size() > maxPageFileSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "page too large"}) + return + } + + content, err := os.ReadFile(cleaned) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read page"}) + return + } + + c.Data(http.StatusOK, "text/markdown; charset=utf-8", content) +} + +// ListPages returns available page slugs. +// GET /api/v1/pages +func (h *PageHandler) ListPages(c *gin.Context) { + entries, err := os.ReadDir(h.pagesDir) + if err != nil { + response.Success(c, []string{}) + return + } + + slugs := make([]string, 0, len(entries)) + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if strings.HasSuffix(name, ".md") { + slugs = append(slugs, strings.TrimSuffix(name, ".md")) + } + } + response.Success(c, slugs) +} + +// ServePageImage serves images from data/pages/{slug}/ directory. +// GET /api/v1/pages/:slug/images/*filename +// No JWT required (browser img tags can't carry tokens), but visibility is checked. +func (h *PageHandler) ServePageImage(c *gin.Context) { + slug := c.Param("slug") + filename := c.Param("filename") + filename = strings.TrimPrefix(filename, "/") + + if !validSlugPattern.MatchString(slug) || len(slug) > 64 { + c.Status(http.StatusNotFound) + return + } + + if !h.checkImageSlugVisibility(c, slug) { + c.Status(http.StatusNotFound) + return + } + + imagesDir := filepath.Join(h.pagesDir, slug) + cleaned, ok := resolvePageImagePath(h.pagesDir, imagesDir, filename) + if !ok { + c.Status(http.StatusNotFound) + return + } + + info, err := os.Stat(cleaned) + if err != nil || info.IsDir() { + c.Status(http.StatusNotFound) + return + } + + c.File(cleaned) +} + +func resolvePageImagePath(pagesDir, imagesDir, filename string) (string, bool) { + relPath, ok := cleanPageImageRelativePath(filename) + if !ok { + return "", false + } + + cleanedPagesDir := filepath.Clean(pagesDir) + cleanedImagesDir := filepath.Clean(imagesDir) + cleanedTarget := filepath.Clean(filepath.Join(cleanedImagesDir, relPath)) + if !isPathWithinBase(cleanedTarget, cleanedImagesDir) { + return "", false + } + + realPagesDir, err := filepath.EvalSymlinks(cleanedPagesDir) + if err != nil { + return "", false + } + realImagesDir, err := filepath.EvalSymlinks(cleanedImagesDir) + if err != nil || !isPathWithinBase(realImagesDir, realPagesDir) { + return "", false + } + realTarget, err := filepath.EvalSymlinks(cleanedTarget) + if err != nil || !isPathWithinBase(realTarget, realImagesDir) { + return "", false + } + return realTarget, true +} + +func cleanPageImageRelativePath(filename string) (string, bool) { + if filename == "" { + return "", false + } + if strings.HasPrefix(filename, "/") { + return "", false + } + decoded, err := url.PathUnescape(filename) + if err != nil { + return "", false + } + if decoded == "" || strings.HasPrefix(decoded, "/") || strings.Contains(decoded, "\\") || strings.ContainsRune(decoded, 0) { + return "", false + } + + parts := make([]string, 0) + for _, part := range strings.Split(decoded, "/") { + switch part { + case "", ".": + continue + case "..": + return "", false + default: + parts = append(parts, part) + } + } + if len(parts) == 0 { + return "", false + } + + relPath := filepath.Join(parts...) + if filepath.IsAbs(relPath) || filepath.VolumeName(relPath) != "" { + return "", false + } + return relPath, true +} + +func isPathWithinBase(path, base string) bool { + rel, err := filepath.Rel(filepath.Clean(base), filepath.Clean(path)) + if err != nil { + return false + } + return rel != "." && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +// findSlugVisibility looks up the slug in custom_menu_items and returns (visibility, found). +func (h *PageHandler) findSlugVisibility(c *gin.Context, slug string) (string, bool) { + if h.settingService == nil { + return "", false + } + + raw := h.settingService.GetCustomMenuItemsRaw(c.Request.Context()) + if raw == "" || raw == "[]" { + return "", false + } + + var items []struct { + URL string `json:"url"` + PageSlug string `json:"page_slug"` + Visibility string `json:"visibility"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return "", false + } + + for _, item := range items { + itemSlug := item.PageSlug + if itemSlug == "" && strings.HasPrefix(item.URL, "md:") { + itemSlug = strings.TrimPrefix(item.URL, "md:") + } + if itemSlug == slug { + return item.Visibility, true + } + } + return "", false +} + +// checkSlugVisibility verifies the slug is configured in custom_menu_items +// and the authenticated user has permission to view it. +func (h *PageHandler) checkSlugVisibility(c *gin.Context, slug string) bool { + visibility, found := h.findSlugVisibility(c, slug) + if !found { + return false + } + if visibility == "admin" { + role, _ := middleware2.GetUserRoleFromContext(c) + return role == "admin" + } + return true +} + +// checkImageSlugVisibility checks visibility for image requests (no JWT available). +// Only allows user-visible pages; admin-only pages are blocked. +func (h *PageHandler) checkImageSlugVisibility(c *gin.Context, slug string) bool { + visibility, found := h.findSlugVisibility(c, slug) + if !found { + return false + } + return visibility != "admin" +} + +// RegisterPageRoutes registers page routes on a router group. +func RegisterPageRoutes(v1 *gin.RouterGroup, dataDir string, jwtAuth gin.HandlerFunc, adminAuth gin.HandlerFunc, settingService *service.SettingService) { + h := NewPageHandler(dataDir, settingService) + + // Authenticated page content (JWT required + visibility check) + pages := v1.Group("/pages") + pages.Use(jwtAuth) + { + pages.GET("/:slug", h.GetPageContent) + } + + // Images: no JWT (browser img tags can't carry tokens), visibility check in handler + pageImages := v1.Group("/pages") + { + pageImages.GET("/:slug/images/*filename", h.ServePageImage) + } + + // Admin-only: list all available pages + adminPages := v1.Group("/pages") + adminPages.Use(adminAuth) + { + adminPages.GET("", h.ListPages) + } +} diff --git a/backend/internal/handler/page_handler_test.go b/backend/internal/handler/page_handler_test.go new file mode 100644 index 00000000..0a9f0d96 --- /dev/null +++ b/backend/internal/handler/page_handler_test.go @@ -0,0 +1,102 @@ +package handler + +import ( + "os" + "path/filepath" + "testing" +) + +func TestCleanPageImageRelativePath(t *testing.T) { + tests := []struct { + name string + in string + want string + ok bool + }{ + {name: "single filename", in: "logo.png", want: "logo.png", ok: true}, + {name: "nested path", in: "images/logo.png", want: filepath.Join("images", "logo.png"), ok: true}, + {name: "dot prefix", in: "./logo.png", want: "logo.png", ok: true}, + {name: "url escaped slash", in: "images%2Flogo.png", want: filepath.Join("images", "logo.png"), ok: true}, + {name: "parent traversal", in: "../secret.png", ok: false}, + {name: "encoded parent traversal", in: "%2e%2e/secret.png", ok: false}, + {name: "backslash traversal", in: `images\secret.png`, ok: false}, + {name: "absolute path", in: "/etc/passwd", ok: false}, + {name: "encoded absolute path", in: "%2fetc/passwd", ok: false}, + {name: "encoded nul byte", in: "logo.png%00", ok: false}, + {name: "invalid escape", in: "logo.png%zz", ok: false}, + {name: "empty path", in: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := cleanPageImageRelativePath(tt.in) + if ok != tt.ok { + t.Fatalf("ok = %v, want %v", ok, tt.ok) + } + if got != tt.want { + t.Fatalf("path = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolvePageImagePath(t *testing.T) { + root := t.TempDir() + pagesDir := filepath.Join(root, "pages") + base := filepath.Join(pagesDir, "guide") + if err := os.MkdirAll(filepath.Join(base, "images"), 0755); err != nil { + t.Fatalf("create images dir: %v", err) + } + if err := os.WriteFile(filepath.Join(base, "logo.png"), []byte("fake"), 0644); err != nil { + t.Fatalf("create direct image: %v", err) + } + if err := os.WriteFile(filepath.Join(base, "images", "logo.png"), []byte("fake"), 0644); err != nil { + t.Fatalf("create image: %v", err) + } + + got, ok := resolvePageImagePath(pagesDir, base, "logo.png") + if !ok { + t.Fatal("expected direct image path to be accepted") + } + want := filepath.Join(base, "logo.png") + if got != want { + t.Fatalf("path = %q, want %q", got, want) + } + + got, ok = resolvePageImagePath(pagesDir, base, "images/logo.png") + if !ok { + t.Fatal("expected nested image path to be accepted") + } + want = filepath.Join(base, "images", "logo.png") + if got != want { + t.Fatalf("path = %q, want %q", got, want) + } + + if got, ok := resolvePageImagePath(pagesDir, base, "../guide.md"); ok { + t.Fatalf("expected traversal to be rejected, got %q", got) + } +} + +func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) { + root := t.TempDir() + pagesDir := filepath.Join(root, "pages") + base := filepath.Join(pagesDir, "guide") + outside := filepath.Join(root, "outside") + + if err := os.MkdirAll(base, 0755); err != nil { + t.Fatalf("create page dir: %v", err) + } + if err := os.MkdirAll(outside, 0755); err != nil { + t.Fatalf("create outside dir: %v", err) + } + if err := os.WriteFile(filepath.Join(outside, "secret.png"), []byte("secret"), 0644); err != nil { + t.Fatalf("create outside file: %v", err) + } + if err := os.Symlink(outside, filepath.Join(base, "images")); err != nil { + t.Skipf("symlink not supported: %v", err) + } + + if got, ok := resolvePageImagePath(pagesDir, base, "images/secret.png"); ok { + t.Fatalf("expected symlink escape to be rejected, got %q", got) + } +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 22f2aa15..6c389e3d 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -40,6 +40,11 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { PasswordResetEnabled: settings.PasswordResetEnabled, InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, + LoginAgreementEnabled: settings.LoginAgreementEnabled, + LoginAgreementMode: settings.LoginAgreementMode, + LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt, + LoginAgreementRevision: settings.LoginAgreementRevision, + LoginAgreementDocuments: publicLoginAgreementDocumentsToDTO(settings.LoginAgreementDocuments), TurnstileEnabled: settings.TurnstileEnabled, TurnstileSiteKey: settings.TurnstileSiteKey, SiteName: settings.SiteName, @@ -63,6 +68,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, + GitHubOAuthEnabled: settings.GitHubOAuthEnabled, + GoogleOAuthEnabled: settings.GoogleOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, Version: h.version, @@ -77,5 +84,19 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { AvailableChannelsEnabled: settings.AvailableChannelsEnabled, AffiliateEnabled: settings.AffiliateEnabled, + + RiskControlEnabled: settings.RiskControlEnabled, }) } + +func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument { + result := make([]dto.LoginAgreementDocument, 0, len(items)) + for _, item := range items { + result = append(result, dto.LoginAgreementDocument{ + ID: item.ID, + Title: item.Title, + ContentMD: item.ContentMD, + }) + } + return result +} diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 192ca1f6..fb690858 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -87,6 +87,8 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index ce161902..cb4ab0a4 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -36,6 +36,7 @@ func ProvideAdminHandlers( channelHandler *admin.ChannelHandler, channelMonitorHandler *admin.ChannelMonitorHandler, channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, + contentModerationHandler *admin.ContentModerationHandler, paymentHandler *admin.PaymentHandler, windsurfHandler *admin.WindsurfHandler, affiliateHandler *admin.AffiliateHandler, @@ -68,6 +69,7 @@ func ProvideAdminHandlers( Channel: channelHandler, ChannelMonitor: channelMonitorHandler, ChannelMonitorTemplate: channelMonitorTemplateHandler, + ContentModeration: contentModerationHandler, Payment: paymentHandler, Windsurf: windsurfHandler, Affiliate: affiliateHandler, @@ -180,6 +182,7 @@ var ProviderSet = wire.NewSet( admin.NewChannelHandler, admin.NewChannelMonitorHandler, admin.NewChannelMonitorRequestTemplateHandler, + admin.NewContentModerationHandler, admin.NewPaymentHandler, admin.NewAffiliateHandler, diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index d7aa1c3b..3c9f5089 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -317,6 +317,7 @@ const CLICurrentVersion = "2.1.92" // - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。 // - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。 // - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。 +// - 不默认加入 redact-thinking,避免上游抹除 thinking 内容;客户端显式传入时由合并逻辑保留。 func FullClaudeCodeMimicryBetas() []string { return []string{ BetaClaudeCode, @@ -324,7 +325,6 @@ func FullClaudeCodeMimicryBetas() []string { BetaInterleavedThinking, BetaPromptCachingScope, BetaEffort, - BetaRedactThinking, BetaContextManagement, BetaExtendedCacheTTL, } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 68895475..43b13937 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID, + apikey.FieldName, apikey.FieldStatus, apikey.FieldIPWhitelist, apikey.FieldIPBlacklist, diff --git a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go index aba62ead..4a462ab1 100644 --- a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go +++ b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go @@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S got, err := repo.GetByKeyForAuth(ctx, key.Key) require.NoError(t, err) + require.Equal(t, key.Name, got.Name) require.NotNil(t, got.Group) require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/repository/content_moderation_hash_cache.go b/backend/internal/repository/content_moderation_hash_cache.go new file mode 100644 index 00000000..782999e7 --- /dev/null +++ b/backend/internal/repository/content_moderation_hash_cache.go @@ -0,0 +1,71 @@ +package repository + +import ( + "context" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes" + +type contentModerationHashCache struct { + rdb *redis.Client +} + +func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache { + return &contentModerationHashCache{rdb: rdb} +} + +func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return nil + } + return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err() +} + +func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return false, nil + } + return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result() +} + +func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return false, nil + } + deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result() + if err != nil { + return false, err + } + return deleted > 0, nil +} + +func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) { + if c == nil || c.rdb == nil { + return 0, nil + } + deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result() + if err != nil { + return 0, err + } + if deleted == 0 { + return 0, nil + } + if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil { + return 0, err + } + return deleted, nil +} + +func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) { + if c == nil || c.rdb == nil { + return 0, nil + } + return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result() +} diff --git a/backend/internal/repository/content_moderation_repo.go b/backend/internal/repository/content_moderation_repo.go new file mode 100644 index 00000000..6ada004a --- /dev/null +++ b/backend/internal/repository/content_moderation_repo.go @@ -0,0 +1,274 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type contentModerationRepository struct { + db *sql.DB +} + +func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository { + return &contentModerationRepository{db: db} +} + +func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { + if log == nil { + return nil + } + categoryScores, err := json.Marshal(log.CategoryScores) + if err != nil { + return fmt.Errorf("marshal moderation category scores: %w", err) + } + thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot) + if err != nil { + return fmt.Errorf("marshal moderation thresholds: %w", err) + } + var userID any + if log.UserID != nil { + userID = *log.UserID + } + var apiKeyID any + if log.APIKeyID != nil { + apiKeyID = *log.APIKeyID + } + var groupID any + if log.GroupID != nil { + groupID = *log.GroupID + } + var latency any + if log.UpstreamLatencyMS != nil { + latency = *log.UpstreamLatencyMS + } + err = r.db.QueryRowContext(ctx, ` +INSERT INTO content_moderation_logs ( + request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name, + endpoint, provider, model, mode, action, flagged, highest_category, highest_score, + category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error, + violation_count, auto_banned, email_sent, queue_delay_ms +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, $11, $12, $13, $14, $15, + $16::jsonb, $17::jsonb, $18, $19, $20, + $21, $22, $23, $24 +) RETURNING id, created_at`, + log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName, + log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore, + string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error, + log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS), + ).Scan(&log.ID, &log.CreatedAt) + if err != nil { + return fmt.Errorf("insert content moderation log: %w", err) + } + return nil +} + +func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { + where, args := buildContentModerationLogWhere(filter) + whereSQL := "WHERE " + strings.Join(where, " AND ") + + var total int64 + if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil { + return nil, nil, fmt.Errorf("count content moderation logs: %w", err) + } + + params := filter.Pagination + if params.Page <= 0 { + params.Page = 1 + } + if params.PageSize <= 0 { + params.PageSize = 20 + } + if params.PageSize > 100 { + params.PageSize = 100 + } + queryArgs := append([]any{}, args...) + queryArgs = append(queryArgs, params.Limit(), params.Offset()) + rows, err := r.db.QueryContext(ctx, ` +SELECT + l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name, + l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score, + l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error, + l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at +FROM content_moderation_logs l +LEFT JOIN users u ON u.id = l.user_id `+whereSQL+` +ORDER BY l.created_at DESC, l.id DESC +LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)), + queryArgs..., + ) + if err != nil { + return nil, nil, fmt.Errorf("list content moderation logs: %w", err) + } + defer func() { _ = rows.Close() }() + + items := make([]service.ContentModerationLog, 0) + for rows.Next() { + var item service.ContentModerationLog + var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64 + var scoresRaw, thresholdsRaw []byte + if err := rows.Scan( + &item.ID, + &item.RequestID, + &userID, + &item.UserEmail, + &apiKeyID, + &item.APIKeyName, + &groupID, + &item.GroupName, + &item.Endpoint, + &item.Provider, + &item.Model, + &item.Mode, + &item.Action, + &item.Flagged, + &item.HighestCategory, + &item.HighestScore, + &scoresRaw, + &thresholdsRaw, + &item.InputExcerpt, + &latency, + &item.Error, + &item.ViolationCount, + &item.AutoBanned, + &item.EmailSent, + &item.UserStatus, + &queueDelay, + &item.CreatedAt, + ); err != nil { + return nil, nil, fmt.Errorf("scan content moderation log: %w", err) + } + if userID.Valid { + v := userID.Int64 + item.UserID = &v + } + if apiKeyID.Valid { + v := apiKeyID.Int64 + item.APIKeyID = &v + } + if groupID.Valid { + v := groupID.Int64 + item.GroupID = &v + } + if latency.Valid { + v := int(latency.Int64) + item.UpstreamLatencyMS = &v + } + if queueDelay.Valid { + v := int(queueDelay.Int64) + item.QueueDelayMS = &v + } + item.CategoryScores = map[string]float64{} + _ = json.Unmarshal(scoresRaw, &item.CategoryScores) + item.ThresholdSnapshot = map[string]float64{} + _ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot) + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err) + } + return items, paginationResultFromTotal(total, params), nil +} + +func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + if userID <= 0 { + return 0, nil + } + var count int + err := r.db.QueryRowContext(ctx, ` +WITH last_auto_ban AS ( + SELECT MAX(created_at) AS at + FROM content_moderation_logs + WHERE user_id = $1 AND auto_banned = TRUE +) +SELECT COUNT(*) +FROM content_moderation_logs +WHERE user_id = $1 + AND flagged = TRUE + AND created_at >= $2 + AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz) +`, userID, since).Scan(&count) + if err != nil { + return 0, fmt.Errorf("count user content moderation flagged logs: %w", err) + } + return count, nil +} + +func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { + result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()} + if r == nil || r.db == nil { + return result, nil + } + hitExec, err := r.db.ExecContext(ctx, ` +DELETE FROM content_moderation_logs +WHERE flagged = TRUE AND created_at < $1 +`, hitBefore) + if err != nil { + return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err) + } + result.DeletedHit, _ = hitExec.RowsAffected() + + nonHitExec, err := r.db.ExecContext(ctx, ` +DELETE FROM content_moderation_logs +WHERE flagged = FALSE AND created_at < $1 +`, nonHitBefore) + if err != nil { + return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err) + } + result.DeletedNonHit, _ = nonHitExec.RowsAffected() + + result.FinishedAt = time.Now() + return result, nil +} + +func nullableIntPtr(value *int) any { + if value == nil { + return nil + } + return *value +} + +func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) { + where := []string{"l.id IS NOT NULL"} + args := make([]any, 0) + add := func(expr string, value any) { + args = append(args, value) + where = append(where, fmt.Sprintf(expr, len(args))) + } + switch strings.ToLower(strings.TrimSpace(filter.Result)) { + case "hit", "flagged": + where = append(where, "l.flagged = TRUE") + case "blocked", "block": + where = append(where, "l.action = 'block'") + case "pass", "allow": + where = append(where, "l.flagged = FALSE AND l.error = ''") + case "error": + where = append(where, "l.error <> ''") + } + if filter.GroupID != nil { + add("l.group_id = $%d", *filter.GroupID) + } + if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" { + add("l.endpoint = $%d", endpoint) + } + if search := strings.TrimSpace(filter.Search); search != "" { + like := "%" + search + "%" + args = append(args, like, like, like, like, like) + idx := len(args) - 4 + where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4)) + } + if filter.From != nil && !filter.From.IsZero() { + add("l.created_at >= $%d", *filter.From) + } + if filter.To != nil && !filter.To.IsZero() { + add("l.created_at <= $%d", *filter.To) + } + return where, args +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index d1f10cbd..1566756d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -737,6 +737,37 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } +func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) { + if len(userIDs) == 0 { + return 0, nil + } + if value < 0 { + value = 0 + } + res, err := r.sql.ExecContext(ctx, + "UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL", + value, pq.Array(userIDs)) + if err != nil { + return 0, fmt.Errorf("batch set concurrency: %w", err) + } + affected, _ := res.RowsAffected() + return int(affected), nil +} + +func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) { + if len(userIDs) == 0 { + return 0, nil + } + res, err := r.sql.ExecContext(ctx, + "UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL", + delta, pq.Array(userIDs)) + if err != nil { + return 0, fmt.Errorf("batch add concurrency: %w", err) + } + affected, _ := res.RowsAffected() + return int(affected), nil +} + func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index f07bbb33..3c0ee9cb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet( NewChannelRepository, NewChannelMonitorRepository, NewChannelMonitorRequestTemplateRepository, + NewContentModerationRepository, NewAffiliateRepository, // Cache implementations @@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet( NewRefreshTokenCache, NewErrorPassthroughCache, NewTLSFingerprintProfileCache, + NewContentModerationHashCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 0c7248d2..1a58c17f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -646,12 +646,21 @@ func TestAPIContracts(t *testing.T) { "registration_email_suffix_whitelist": [], "promo_code_enabled": true, "password_reset_enabled": false, - "frontend_url": "", - "totp_enabled": false, - "totp_encryption_key_configured": false, - "smtp_host": "smtp.example.com", - "smtp_port": 587, - "smtp_username": "user", + "frontend_url": "", + "totp_enabled": false, + "totp_encryption_key_configured": false, + "login_agreement_enabled": false, + "login_agreement_mode": "modal", + "login_agreement_updated_at": "2026-03-31", + "login_agreement_documents": [ + {"id": "terms", "title": "服务条款", "content_md": ""}, + {"id": "usage-policy", "title": "使用政策", "content_md": ""}, + {"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""}, + {"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""} + ], + "smtp_host": "smtp.example.com", + "smtp_port": 587, + "smtp_username": "user", "smtp_password_configured": true, "smtp_from_email": "no-reply@example.com", "smtp_from_name": "Sub2API", @@ -685,6 +694,16 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_userinfo_email_path": "", "oidc_connect_userinfo_id_path": "", "oidc_connect_userinfo_username_path": "", + "github_oauth_enabled": false, + "github_oauth_client_id": "", + "github_oauth_client_secret_configured": false, + "github_oauth_redirect_url": "", + "github_oauth_frontend_redirect_url": "/auth/oauth/callback", + "google_oauth_enabled": false, + "google_oauth_client_id": "", + "google_oauth_client_secret_configured": false, + "google_oauth_redirect_url": "", + "google_oauth_frontend_redirect_url": "/auth/oauth/callback", "ops_monitoring_enabled": false, "ops_realtime_monitoring_enabled": true, "ops_query_mode_default": "auto", @@ -700,6 +719,16 @@ func TestAPIContracts(t *testing.T) { "auth_source_default_email_subscriptions": [], "auth_source_default_email_grant_on_signup": false, "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_github_balance": 0, + "auth_source_default_github_concurrency": 5, + "auth_source_default_github_subscriptions": [], + "auth_source_default_github_grant_on_signup": false, + "auth_source_default_github_grant_on_first_bind": false, + "auth_source_default_google_balance": 0, + "auth_source_default_google_concurrency": 5, + "auth_source_default_google_subscriptions": [], + "auth_source_default_google_grant_on_signup": false, + "auth_source_default_google_grant_on_first_bind": false, "auth_source_default_linuxdo_balance": 0, "auth_source_default_linuxdo_concurrency": 5, "auth_source_default_linuxdo_subscriptions": [], @@ -792,6 +821,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "risk_control_enabled": false, "affiliate_enabled": false, "wechat_connect_enabled": false, "wechat_connect_app_id": "", @@ -859,12 +889,21 @@ func TestAPIContracts(t *testing.T) { "promo_code_enabled": true, "password_reset_enabled": false, "frontend_url": "", - "invitation_code_enabled": false, - "totp_enabled": false, - "totp_encryption_key_configured": false, - "smtp_host": "", - "smtp_port": 587, - "smtp_username": "", + "invitation_code_enabled": false, + "totp_enabled": false, + "totp_encryption_key_configured": false, + "login_agreement_enabled": false, + "login_agreement_mode": "modal", + "login_agreement_updated_at": "2026-03-31", + "login_agreement_documents": [ + {"id": "terms", "title": "服务条款", "content_md": ""}, + {"id": "usage-policy", "title": "使用政策", "content_md": ""}, + {"id": "supported-regions", "title": "支持的国家和地区", "content_md": ""}, + {"id": "service-specific-terms", "title": "服务特定条款", "content_md": ""} + ], + "smtp_host": "", + "smtp_port": 587, + "smtp_username": "", "smtp_password_configured": false, "smtp_from_email": "", "smtp_from_name": "", @@ -898,6 +937,16 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_userinfo_email_path": "", "oidc_connect_userinfo_id_path": "", "oidc_connect_userinfo_username_path": "", + "github_oauth_enabled": false, + "github_oauth_client_id": "", + "github_oauth_client_secret_configured": false, + "github_oauth_redirect_url": "", + "github_oauth_frontend_redirect_url": "/auth/oauth/callback", + "google_oauth_enabled": false, + "google_oauth_client_id": "", + "google_oauth_client_secret_configured": false, + "google_oauth_redirect_url": "", + "google_oauth_frontend_redirect_url": "/auth/oauth/callback", "site_name": "Sub2API", "site_logo": "", "site_subtitle": "Subscription to API Conversion Platform", @@ -983,6 +1032,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "risk_control_enabled": false, "affiliate_enabled": false, "wechat_connect_enabled": true, "wechat_connect_app_id": "wx-open-config", @@ -1005,6 +1055,16 @@ func TestAPIContracts(t *testing.T) { "auth_source_default_email_subscriptions": [], "auth_source_default_email_grant_on_signup": false, "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_github_balance": 0, + "auth_source_default_github_concurrency": 5, + "auth_source_default_github_subscriptions": [], + "auth_source_default_github_grant_on_signup": false, + "auth_source_default_github_grant_on_first_bind": false, + "auth_source_default_google_balance": 0, + "auth_source_default_google_concurrency": 5, + "auth_source_default_google_subscriptions": [], + "auth_source_default_google_grant_on_signup": false, + "auth_source_default_google_grant_on_first_bind": false, "auth_source_default_linuxdo_balance": 0, "auth_source_default_linuxdo_concurrency": 5, "auth_source_default_linuxdo_subscriptions": [], @@ -1123,7 +1183,7 @@ func newContractDeps(t *testing.T) *contractDeps { subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) - redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) + redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil, nil) redeemHandler := handler.NewRedeemHandler(redeemService) settingRepo := newStubSettingRepo() @@ -1294,6 +1354,9 @@ func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i return errors.New("not implemented") } +func (r *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (r *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { return false, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index dde92dfd..3fbbb716 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -198,6 +198,9 @@ func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i panic("unexpected UpdateConcurrency call") } +func (s *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { panic("unexpected ExistsByEmail call") } diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go index ae53037e..157f06b0 100644 --- a/backend/internal/server/middleware/backend_mode_guard.go +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -40,6 +40,8 @@ func backendModeAllowsAuthPath(path string) bool { "/auth/oauth/wechat/callback", "/auth/oauth/wechat/payment/callback", "/auth/oauth/oidc/callback", + "/auth/oauth/github/callback", + "/auth/oauth/google/callback", "/auth/oauth/linuxdo/complete-registration", "/auth/oauth/wechat/complete-registration", "/auth/oauth/oidc/complete-registration", diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go index bd77677b..de9c9ec9 100644 --- a/backend/internal/server/middleware/backend_mode_guard_test.go +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -246,6 +246,30 @@ func TestBackendModeAuthGuard(t *testing.T) { path: "/api/v1/auth/oauth/oidc/callback", wantStatus: http.StatusOK, }, + { + name: "enabled_blocks_github_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/github/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_github_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/github/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_google_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/google/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_google_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/google/callback", + wantStatus: http.StatusOK, + }, { name: "enabled_allows_oauth_pending_exchange", enabled: "true", diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 1c9970cc..9fdf3da5 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -120,4 +120,6 @@ func registerRoutes( routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster) routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService) + + handler.RegisterPageRoutes(v1, cfg.Pricing.DataDir, gin.HandlerFunc(jwtAuth), gin.HandlerFunc(adminAuth), settingService) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 53047bdf..de1d06ad 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -95,11 +95,28 @@ func RegisterAdminRoutes( // 渠道监控 registerChannelMonitorRoutes(admin, h) + // 风控中心 + registerContentModerationRoutes(admin, h) + // 邀请返利(专属用户管理) registerAffiliateRoutes(admin, h) } } +func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + risk := admin.Group("/risk-control") + { + risk.GET("/config", h.Admin.ContentModeration.GetConfig) + risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig) + risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys) + risk.GET("/status", h.Admin.ContentModeration.GetStatus) + risk.GET("/logs", h.Admin.ContentModeration.ListLogs) + risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser) + risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash) + risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes) + } +} + func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { apiKeys := admin.Group("/api-keys") { @@ -234,6 +251,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) + users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -270,6 +288,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("/:id", h.Admin.Account.GetByID) accounts.POST("", h.Admin.Account.Create) accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel) + accounts.POST("/import/codex-session", h.Admin.Account.ImportCodexSession) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS) accounts.PUT("/:id", h.Admin.Account.Update) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 642a2103..54d40e92 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -63,6 +63,22 @@ func RegisterAuthRoutes( FailureMode: middleware.RateLimitFailClose, }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/github/start", h.Auth.GitHubOAuthStart) + auth.GET("/oauth/github/callback", h.Auth.GitHubOAuthCallback) + auth.POST("/oauth/github/complete-registration", + rateLimiter.LimitWithOptions("oauth-github-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteGitHubOAuthRegistration, + ) + auth.GET("/oauth/google/start", h.Auth.GoogleOAuthStart) + auth.GET("/oauth/google/callback", h.Auth.GoogleOAuthCallback) + auth.POST("/oauth/google/complete-registration", + rateLimiter.LimitWithOptions("oauth-google-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteGoogleOAuthRegistration, + ) auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) { query := c.Request.URL.Query() query.Set("intent", "bind_current_user") diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 1bf44218..24afbd68 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -33,6 +33,7 @@ type AdminService interface { UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) DeleteUser(ctx context.Context, id int64) error UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) + BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) @@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { return nil } +func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) { + cleaned := make([]int64, 0, len(userIDs)) + for _, uid := range userIDs { + if uid > 0 { + cleaned = append(cleaned, uid) + } + } + if len(cleaned) == 0 { + return 0, nil + } + + var affected int + var err error + switch mode { + case "set": + affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value) + case "add": + affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value) + default: + return 0, errors.New("invalid mode: must be 'set' or 'add'") + } + if err != nil { + return 0, err + } + + if s.authCacheInvalidator != nil { + for _, uid := range cleaned { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid) + } + } + return affected, nil +} + func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index fcde5cbf..3b3dbc21 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } + +func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fe9e7701..a9492a1d 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i panic("unexpected UpdateConcurrency call") } +func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) { if s.existsErr != nil { return false, s.existsErr diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go index 2232c9c3..c791b747 100644 --- a/backend/internal/service/admin_service_email_identity_sync_test.go +++ b/backend/internal/service/admin_service_email_identity_sync_test.go @@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) return 0, nil } +func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 4432ad7d..3553a18a 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct { APIKeyID int64 `json:"api_key_id"` UserID int64 `json:"user_id"` GroupID *int64 `json:"group_id,omitempty"` + Name string `json:"name"` Status string `json:"status"` IPWhitelist []string `json:"ip_whitelist,omitempty"` IPBlacklist []string `json:"ip_blacklist,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 0f9d4214..877888b1 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls +const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs type apiKeyAuthCacheConfig struct { l1Size int @@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) APIKeyID: apiKey.ID, UserID: apiKey.UserID, GroupID: apiKey.GroupID, + Name: apiKey.Name, Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, @@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho UserID: snapshot.UserID, GroupID: snapshot.GroupID, Key: key, + Name: snapshot.Name, Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 8cb1b8c4..eaac9a1c 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t UserID: 2, GroupID: &groupID, Key: "k-roundtrip", + Name: "Audit Key", Status: StatusActive, User: &User{ ID: 2, @@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) require.NotNil(t, roundTrip) + require.Equal(t, apiKey.Name, roundTrip.Name) require.NotNil(t, roundTrip.Group) require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/service/auth_email_oauth_auto.go b/backend/internal/service/auth_email_oauth_auto.go new file mode 100644 index 00000000..56fd4004 --- /dev/null +++ b/backend/internal/service/auth_email_oauth_auto.go @@ -0,0 +1,274 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/mail" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +type EmailOAuthIdentityInput struct { + ProviderType string + ProviderKey string + ProviderSubject string + Email string + EmailVerified bool + Username string + DisplayName string + AvatarURL string + UpstreamMetadata map[string]any +} + +func (s *AuthService) LoginOrRegisterVerifiedEmailOAuth(ctx context.Context, input EmailOAuthIdentityInput) (*TokenPair, *User, error) { + return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, "", "") +} + +func (s *AuthService) LoginOrRegisterVerifiedEmailOAuthWithInvitation( + ctx context.Context, + input EmailOAuthIdentityInput, + invitationCode string, + affiliateCode string, +) (*TokenPair, *User, error) { + return s.loginOrRegisterVerifiedEmailOAuth(ctx, input, invitationCode, affiliateCode) +} + +func (s *AuthService) loginOrRegisterVerifiedEmailOAuth( + ctx context.Context, + input EmailOAuthIdentityInput, + invitationCode string, + affiliateCode string, +) (*TokenPair, *User, error) { + if s == nil || s.userRepo == nil || s.entClient == nil { + return nil, nil, ErrServiceUnavailable + } + + providerType := normalizeOAuthSignupSource(input.ProviderType) + if providerType != "github" && providerType != "google" { + return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid") + } + providerKey := strings.TrimSpace(input.ProviderKey) + if providerKey == "" { + providerKey = providerType + } + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerSubject == "" { + return nil, nil, infraerrors.BadRequest("OAUTH_SUBJECT_MISSING", "oauth subject is missing") + } + if !input.EmailVerified { + return nil, nil, infraerrors.Forbidden("OAUTH_EMAIL_NOT_VERIFIED", "oauth email is not verified") + } + + email := strings.TrimSpace(strings.ToLower(input.Email)) + if email == "" || len(email) > 255 { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + + identityUser, err := s.findEmailOAuthIdentityOwner(ctx, providerType, providerKey, providerSubject) + if err != nil { + return nil, nil, err + } + if identityUser != nil && !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) { + return nil, nil, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email") + } + + user := identityUser + created := false + if user == nil { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + user, err = s.createEmailOAuthUser(ctx, email, input.Username, providerType, invitationCode, affiliateCode) + if err != nil { + return nil, nil, err + } + created = true + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error during %s oauth login: %v", providerType, err) + return nil, nil, ErrServiceUnavailable + } + } + } + + if !user.IsActive() { + return nil, nil, ErrUserNotActive + } + if err := s.ensureEmailOAuthIdentity(ctx, user.ID, EmailOAuthIdentityInput{ + ProviderType: providerType, + ProviderKey: providerKey, + ProviderSubject: providerSubject, + Email: email, + EmailVerified: input.EmailVerified, + Username: input.Username, + DisplayName: input.DisplayName, + AvatarURL: input.AvatarURL, + UpstreamMetadata: input.UpstreamMetadata, + }); err != nil { + return nil, nil, err + } + + if user.Username == "" && strings.TrimSpace(input.Username) != "" { + user.Username = strings.TrimSpace(input.Username) + if err := s.userRepo.Update(ctx, user); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after %s oauth login: %v", providerType, err) + } + } + if !created { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, providerType); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply %s first bind defaults: %v", providerType, err) + } + } + s.RecordSuccessfulLogin(ctx, user.ID) + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + +func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, providerType, invitationCode, affiliateCode string) (*User, error) { + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, ErrRegDisabled + } + invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + if errors.Is(err, ErrInvitationCodeRequired) { + return nil, ErrOAuthInvitationRequired + } + return nil, err + } + + randomPassword, err := randomHexString(32) + if err != nil { + return nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return nil, fmt.Errorf("hash password: %w", err) + } + grantPlan := s.resolveSignupGrantPlan(ctx, providerType) + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } + user := &User{ + Email: email, + Username: strings.TrimSpace(username), + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, + Status: StatusActive, + SignupSource: providerType, + } + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + existing, loadErr := s.userRepo.GetByEmail(ctx, email) + if loadErr != nil { + return nil, ErrServiceUnavailable + } + return existing, nil + } + return nil, ErrServiceUnavailable + } + s.postAuthUserBootstrap(ctx, user, providerType, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) + if invitationRedeemCode != nil { + if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil { + _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, invitationCode) + return nil, ErrInvitationCodeInvalid + } + } + return user, nil +} + +func (s *AuthService) findEmailOAuthIdentityOwner(ctx context.Context, providerType, providerKey, providerSubject string) (*User, error) { + identity, err := s.entClient.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + user, err := s.userRepo.GetByID(ctx, identity.UserID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, nil + } + return nil, ErrServiceUnavailable + } + return user, nil +} + +func (s *AuthService) ensureEmailOAuthIdentity(ctx context.Context, userID int64, input EmailOAuthIdentityInput) error { + metadata := map[string]any{ + "email": strings.TrimSpace(strings.ToLower(input.Email)), + "email_verified": input.EmailVerified, + } + for key, value := range input.UpstreamMetadata { + metadata[key] = value + } + if strings.TrimSpace(input.Username) != "" { + metadata["username"] = strings.TrimSpace(input.Username) + } + if strings.TrimSpace(input.DisplayName) != "" { + metadata["display_name"] = strings.TrimSpace(input.DisplayName) + } + if strings.TrimSpace(input.AvatarURL) != "" { + metadata["avatar_url"] = strings.TrimSpace(input.AvatarURL) + } + + providerType := normalizeOAuthSignupSource(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + identity, err := s.entClient.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if identity != nil { + if identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + _, err = s.entClient.AuthIdentity.UpdateOneID(identity.ID). + SetMetadata(metadata). + Save(ctx) + return err + } + _, err = s.entClient.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(metadata). + Save(ctx) + return err +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index 9815f31b..e3c8298c 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -10,6 +10,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) func normalizeOAuthSignupSource(signupSource string) string { @@ -17,7 +18,7 @@ func normalizeOAuthSignupSource(signupSource string) string { switch signupSource { case "", "email": return "email" - case "linuxdo", "wechat", "oidc": + case "linuxdo", "wechat", "oidc", "github", "google": return signupSource default: return "email" @@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount( return tokenPair, user, nil } +// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth +// provider that has already returned a verified email address. +func (s *AuthService) RegisterVerifiedOAuthEmailAccount( + ctx context.Context, + email string, + password string, + invitationCode string, + signupSource string, +) (*TokenPair, *User, error) { + if s == nil { + return nil, nil, ErrServiceUnavailable + } + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || len(email) > 255 { + return nil, nil, ErrEmailVerifyRequired + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, nil, ErrEmailVerifyRequired + } + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + if strings.TrimSpace(password) == "" { + return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required") + } + if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil { + return nil, nil, err + } + + existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) + if err != nil { + return nil, nil, ErrServiceUnavailable + } + if existsEmail { + return nil, nil, ErrEmailExists + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + signupSource = normalizeOAuthSignupSource(signupSource) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } + user := &User{ + Email: email, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, + Status: StatusActive, + SignupSource: signupSource, + } + + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, nil, ErrEmailExists + } + return nil, nil, ErrServiceUnavailable + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "") + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + // FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap // only after the pending OAuth flow has fully reached its last reversible step. func (s *AuthService) FinalizeOAuthEmailAccount( diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index 21d9d6e9..cd76c6b7 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -229,6 +229,67 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes require.Equal(t, "oidc", userRepo.created[0].SignupSource) } +func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) { + tests := []struct { + name string + email string + signupSource string + want string + }{ + { + name: "github", + email: "github@example.com", + signupSource: " GitHub ", + want: "github", + }, + { + name: "google", + email: "google@example.com", + signupSource: " Google ", + want: "google", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userRepo := &userRepoStub{nextID: 43} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + tt.email, + "secret-123", + "246810", + "", + tt.signupSource, + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, tt.want, userRepo.created[0].SignupSource) + }) + } +} + func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) { userRepo := &userRepoStub{nextID: 43} emailCache := &emailCacheStub{ @@ -256,7 +317,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing "secret-123", "246810", "", - "github", + "unknown-provider", ) require.NoError(t, err) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index b1adf071..e01e8217 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -775,6 +775,10 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource return defaults.OIDC, true case "wechat": return defaults.WeChat, true + case "github": + return defaults.GitHub, true + case "google": + return defaults.Google, true default: return ProviderDefaultGrantSettings{}, false } diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index e2392e4b..6845c1f4 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) ( return ok, nil } +func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } diff --git a/backend/internal/service/codex_image_generation_bridge.go b/backend/internal/service/codex_image_generation_bridge.go new file mode 100644 index 00000000..c7a894a7 --- /dev/null +++ b/backend/internal/service/codex_image_generation_bridge.go @@ -0,0 +1,64 @@ +package service + +import "strings" + +const featureKeyCodexImageGenerationBridge = "codex_image_generation_bridge" + +func boolOverridePtr(v bool) *bool { + return &v +} + +func boolOverrideFromMap(values map[string]any, keys ...string) *bool { + if values == nil { + return nil + } + for _, key := range keys { + if v, ok := values[key].(bool); ok { + return boolOverridePtr(v) + } + } + return nil +} + +func platformBoolOverride(values map[string]any, key string, platform string) *bool { + if values == nil { + return nil + } + if v, ok := values[key].(bool); ok { + return boolOverridePtr(v) + } + raw, ok := values[key].(map[string]any) + if !ok { + return nil + } + platform = strings.TrimSpace(platform) + if platform == "" { + return nil + } + if v, ok := raw[platform].(bool); ok { + return boolOverridePtr(v) + } + return nil +} + +// CodexImageGenerationBridgeOverride returns the channel-level override for Codex +// image_generation bridge injection. Nil means follow the global/account policy. +func (c *Channel) CodexImageGenerationBridgeOverride(platform string) *bool { + if c == nil { + return nil + } + return platformBoolOverride(c.FeaturesConfig, featureKeyCodexImageGenerationBridge, platform) +} + +// CodexImageGenerationBridgeOverride returns the account-level override for Codex +// image_generation bridge injection. Nil means follow the channel/global policy. +func (a *Account) CodexImageGenerationBridgeOverride() *bool { + if a == nil || a.Platform != PlatformOpenAI || a.Extra == nil { + return nil + } + if override := boolOverrideFromMap(a.Extra, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled"); override != nil { + return override + } + openaiConfig, _ := a.Extra[PlatformOpenAI].(map[string]any) + return boolOverrideFromMap(openaiConfig, featureKeyCodexImageGenerationBridge, "codex_image_generation_bridge_enabled") +} diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go new file mode 100644 index 00000000..144222c2 --- /dev/null +++ b/backend/internal/service/content_moderation.go @@ -0,0 +1,2048 @@ +package service + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + ContentModerationModeOff = "off" + ContentModerationModeObserve = "observe" + ContentModerationModePreBlock = "pre_block" + + contentModerationAPIKeysModeAppend = "append" + contentModerationAPIKeysModeReplace = "replace" + + ContentModerationActionAllow = "allow" + ContentModerationActionBlock = "block" + ContentModerationActionHashBlock = "hash_block" + ContentModerationActionError = "error" + + ContentModerationProtocolAnthropicMessages = "anthropic_messages" + ContentModerationProtocolOpenAIResponses = "openai_responses" + ContentModerationProtocolOpenAIChat = "openai_chat_completions" + ContentModerationProtocolGemini = "gemini" + ContentModerationProtocolOpenAIImages = "openai_images" + + defaultContentModerationBaseURL = "https://api.openai.com" + defaultContentModerationModel = "omni-moderation-latest" + defaultContentModerationTimeoutMS = 3000 + maxContentModerationTimeoutMS = 30000 + maxModerationInputRunes = 12000 + maxModerationExcerptRunes = 240 + + defaultContentModerationWorkerCount = 4 + maxContentModerationWorkerCount = 32 + defaultContentModerationQueueSize = 32768 + maxContentModerationQueueSize = 100000 + defaultContentModerationBanThreshold = 10 + defaultContentModerationViolationWindowHours = 720 + defaultContentModerationBlockHTTPStatus = http.StatusForbidden + defaultContentModerationBlockMessage = "内容审计命中风险规则,请调整输入后重试" + defaultContentModerationRetryCount = 2 + maxContentModerationRetryCount = 5 + defaultContentModerationHitRetentionDays = 180 + defaultContentModerationNonHitRetentionDays = 3 + maxContentModerationRetentionDays = 3650 + maxContentModerationNonHitRetentionDays = 3 + contentModerationKeyRateLimitFreezeDuration = time.Minute + contentModerationKeyAuthFreezeDuration = 10 * time.Minute + contentModerationKeyHTTPErrorFreezeDuration = 10 * time.Second + maxContentModerationInputImages = 1 + maxContentModerationTestImages = maxContentModerationInputImages + maxContentModerationTestImageBytes = 8 * 1024 * 1024 + maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024 + + contentModerationCleanupInterval = 24 * time.Hour + contentModerationCleanupTimeout = 30 * time.Minute + contentModerationCleanupDelay = 5 * time.Minute +) + +var contentModerationCategoryOrder = []string{ + "harassment", + "harassment/threatening", + "hate", + "hate/threatening", + "illicit", + "illicit/violent", + "self-harm", + "self-harm/intent", + "self-harm/instructions", + "sexual", + "sexual/minors", + "violence", + "violence/graphic", +} + +func ContentModerationDefaultThresholds() map[string]float64 { + return map[string]float64{ + "harassment": 0.98, + "harassment/threatening": 0.90, + "hate": 0.65, + "hate/threatening": 0.65, + "illicit": 0.95, + "illicit/violent": 0.95, + "self-harm": 0.65, + "self-harm/intent": 0.85, + "self-harm/instructions": 0.65, + "sexual": 0.65, + "sexual/minors": 0.65, + "violence": 0.95, + "violence/graphic": 0.95, + } +} + +func ContentModerationCategories() []string { + out := make([]string, len(contentModerationCategoryOrder)) + copy(out, contentModerationCategoryOrder) + return out +} + +type ContentModerationConfig struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKey string `json:"api_key,omitempty"` + APIKeys []string `json:"api_keys,omitempty"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + Thresholds map[string]float64 `json:"thresholds"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationConfigView struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKeyConfigured bool `json:"api_key_configured"` + APIKeyMasked string `json:"api_key_masked"` + APIKeyCount int `json:"api_key_count"` + APIKeyMasks []string `json:"api_key_masks"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationAPIKeyStatus struct { + Index int `json:"index"` + KeyHash string `json:"key_hash"` + Masked string `json:"masked"` + Status string `json:"status"` + FailureCount int `json:"failure_count"` + SuccessCount int64 `json:"success_count"` + LastError string `json:"last_error"` + LastCheckedAt *time.Time `json:"last_checked_at,omitempty"` + FrozenUntil *time.Time `json:"frozen_until,omitempty"` + LastLatencyMS int `json:"last_latency_ms"` + LastHTTPStatus int `json:"last_http_status"` + LastTested bool `json:"last_tested"` + Configured bool `json:"configured"` +} + +type TestContentModerationAPIKeysInput struct { + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + TimeoutMS int `json:"timeout_ms"` + Prompt string `json:"prompt"` + Images []string `json:"images"` +} + +type TestContentModerationAPIKeysResult struct { + Items []ContentModerationAPIKeyStatus `json:"items"` + AuditResult *ContentModerationTestAuditResult `json:"audit_result,omitempty"` + ImageCount int `json:"image_count"` +} + +type ContentModerationTestAuditResult struct { + Flagged bool `json:"flagged"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CompositeScore float64 `json:"composite_score"` + CategoryScores map[string]float64 `json:"category_scores"` + Thresholds map[string]float64 `json:"thresholds"` +} + +type UpdateContentModerationConfigInput struct { + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + APIKeysMode string `json:"api_keys_mode"` + DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationCheckInput struct { + RequestID string + UserID int64 + UserEmail string + APIKeyID int64 + APIKeyName string + GroupID *int64 + GroupName string + Endpoint string + Provider string + Model string + Protocol string + Body []byte +} + +type ContentModerationInput struct { + Text string + Images []string +} + +func (in *ContentModerationInput) Normalize() { + if in == nil { + return + } + in.Text = trimRunes(normalizeContentModerationText(in.Text), maxModerationInputRunes) + in.Images = normalizeModerationImages(in.Images) +} + +func (in ContentModerationInput) IsEmpty() bool { + return strings.TrimSpace(in.Text) == "" && len(in.Images) == 0 +} + +func (in ContentModerationInput) ModerationInput() any { + images := limitContentModerationImages(in.Images) + if len(images) == 0 { + return in.Text + } + parts := make([]moderationAPIInputPart, 0, len(images)+1) + if strings.TrimSpace(in.Text) != "" { + parts = append(parts, moderationAPIInputPart{Type: "text", Text: in.Text}) + } + for _, image := range images { + parts = append(parts, moderationAPIInputPart{ + Type: "image_url", + ImageURL: &moderationAPIImageURLRef{URL: image}, + }) + } + return parts +} + +func (in ContentModerationInput) ExcerptText() string { + return in.Text +} + +func (in ContentModerationInput) Hash() string { + h := sha256.New() + _, _ = h.Write([]byte("text:")) + _, _ = h.Write([]byte(in.Text)) + for _, image := range in.Images { + imageHash := sha256.Sum256([]byte(image)) + _, _ = h.Write([]byte("\nimage:")) + _, _ = h.Write([]byte(hex.EncodeToString(imageHash[:]))) + } + return hex.EncodeToString(h.Sum(nil)) +} + +type ContentModerationDecision struct { + Allowed bool `json:"allowed"` + Blocked bool `json:"blocked"` + Flagged bool `json:"flagged"` + Message string `json:"message"` + StatusCode int `json:"status_code"` + InputHash string `json:"input_hash,omitempty"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CategoryScores map[string]float64 `json:"category_scores"` + Action string `json:"action"` +} + +type ContentModerationLog struct { + ID int64 `json:"id"` + RequestID string `json:"request_id"` + UserID *int64 `json:"user_id,omitempty"` + UserEmail string `json:"user_email"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + APIKeyName string `json:"api_key_name"` + GroupID *int64 `json:"group_id,omitempty"` + GroupName string `json:"group_name"` + Endpoint string `json:"endpoint"` + Provider string `json:"provider"` + Model string `json:"model"` + Mode string `json:"mode"` + Action string `json:"action"` + Flagged bool `json:"flagged"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CategoryScores map[string]float64 `json:"category_scores"` + ThresholdSnapshot map[string]float64 `json:"threshold_snapshot"` + InputExcerpt string `json:"input_excerpt"` + UpstreamLatencyMS *int `json:"upstream_latency_ms,omitempty"` + Error string `json:"error"` + ViolationCount int `json:"violation_count"` + AutoBanned bool `json:"auto_banned"` + EmailSent bool `json:"email_sent"` + UserStatus string `json:"user_status"` + QueueDelayMS *int `json:"queue_delay_ms,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type ContentModerationLogFilter struct { + Pagination pagination.PaginationParams + Result string + GroupID *int64 + Endpoint string + Search string + From *time.Time + To *time.Time +} + +type ContentModerationCleanupResult struct { + DeletedHit int64 `json:"deleted_hit"` + DeletedNonHit int64 `json:"deleted_non_hit"` + FinishedAt time.Time `json:"finished_at"` +} + +type ContentModerationRuntimeStatus struct { + Enabled bool `json:"enabled"` + RiskControlEnabled bool `json:"risk_control_enabled"` + Mode string `json:"mode"` + WorkerCount int `json:"worker_count"` + MaxWorkers int `json:"max_workers"` + ActiveWorkers int `json:"active_workers"` + IdleWorkers int `json:"idle_workers"` + QueueSize int `json:"queue_size"` + QueueLength int `json:"queue_length"` + QueueUsagePercent float64 `json:"queue_usage_percent"` + Enqueued int64 `json:"enqueued"` + Dropped int64 `json:"dropped"` + Processed int64 `json:"processed"` + Errors int64 `json:"errors"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + FlaggedHashCount int64 `json:"flagged_hash_count"` + LastCleanupAt *time.Time `json:"last_cleanup_at,omitempty"` + LastCleanupDeletedHit int64 `json:"last_cleanup_deleted_hit"` + LastCleanupDeletedNonHit int64 `json:"last_cleanup_deleted_non_hit"` +} + +type ContentModerationUnbanUserResult struct { + UserID int64 `json:"user_id"` + Status string `json:"status"` +} + +type ContentModerationDeleteHashResult struct { + InputHash string `json:"input_hash"` + Deleted bool `json:"deleted"` +} + +type ContentModerationClearHashesResult struct { + Deleted int64 `json:"deleted"` +} + +type ContentModerationRepository interface { + CreateLog(ctx context.Context, log *ContentModerationLog) error + ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) + CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) + CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) +} + +type ContentModerationHashCache interface { + RecordFlaggedInputHash(ctx context.Context, inputHash string) error + HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) + DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) + ClearFlaggedInputHashes(ctx context.Context) (int64, error) + CountFlaggedInputHashes(ctx context.Context) (int64, error) +} + +type ContentModerationService struct { + settingRepo SettingRepository + repo ContentModerationRepository + hashCache ContentModerationHashCache + groupRepo GroupRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + emailService *EmailService + httpClient *http.Client + asyncQueue chan contentModerationTask + workerCount int + apiKeyCursor atomic.Uint64 + asyncActive atomic.Int64 + asyncEnqueued atomic.Int64 + asyncDropped atomic.Int64 + asyncProcessed atomic.Int64 + asyncErrors atomic.Int64 + lastCleanupUnix atomic.Int64 + lastCleanupDeletedHit atomic.Int64 + lastCleanupDeletedNonHit atomic.Int64 + keyHealthMu sync.Mutex + keyHealth map[string]*contentModerationKeyHealth +} + +type contentModerationTask struct { + input ContentModerationCheckInput + content ContentModerationInput + inputHash string + enqueuedAt time.Time +} + +type contentModerationKeyHealth struct { + Hash string + Masked string + FailureCount int + SuccessCount int64 + LastError string + LastCheckedAt time.Time + FrozenUntil time.Time + LastLatencyMS int + LastHTTPStatus int + LastTested bool +} + +func NewContentModerationService( + settingRepo SettingRepository, + repo ContentModerationRepository, + hashCache ContentModerationHashCache, + groupRepo GroupRepository, + userRepo UserRepository, + authCacheInvalidator APIKeyAuthCacheInvalidator, + emailService *EmailService, +) *ContentModerationService { + svc := &ContentModerationService{ + settingRepo: settingRepo, + repo: repo, + hashCache: hashCache, + groupRepo: groupRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, + emailService: emailService, + httpClient: &http.Client{}, + workerCount: maxContentModerationWorkerCount, + asyncQueue: make(chan contentModerationTask, maxContentModerationQueueSize), + keyHealth: make(map[string]*contentModerationKeyHealth), + } + if settingRepo != nil && repo != nil { + for i := 0; i < svc.workerCount; i++ { + go svc.worker(i) + } + go svc.cleanupWorker() + } + return svc +} + +func (s *ContentModerationService) GetConfig(ctx context.Context) (*ContentModerationConfigView, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + return s.configView(cfg), nil +} + +func (s *ContentModerationService) UpdateConfig(ctx context.Context, input UpdateContentModerationConfigInput) (*ContentModerationConfigView, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + if input.Enabled != nil { + cfg.Enabled = *input.Enabled + } + if input.Mode != nil { + cfg.Mode = strings.TrimSpace(*input.Mode) + } + if input.BaseURL != nil { + cfg.BaseURL = strings.TrimSpace(*input.BaseURL) + } + if input.Model != nil { + cfg.Model = strings.TrimSpace(*input.Model) + } + if input.TimeoutMS != nil { + cfg.TimeoutMS = *input.TimeoutMS + } + if input.SampleRate != nil { + cfg.SampleRate = *input.SampleRate + } + if input.WorkerCount != nil { + cfg.WorkerCount = *input.WorkerCount + } + if input.QueueSize != nil { + cfg.QueueSize = *input.QueueSize + } + if input.BlockStatus != nil { + cfg.BlockStatus = *input.BlockStatus + } + if input.BlockMessage != nil { + cfg.BlockMessage = strings.TrimSpace(*input.BlockMessage) + } + if input.EmailOnHit != nil { + cfg.EmailOnHit = *input.EmailOnHit + } + if input.AutoBanEnabled != nil { + cfg.AutoBanEnabled = *input.AutoBanEnabled + } + if input.BanThreshold != nil { + cfg.BanThreshold = *input.BanThreshold + } + if input.ViolationWindowHours != nil { + cfg.ViolationWindowHours = *input.ViolationWindowHours + } + if input.RetryCount != nil { + cfg.RetryCount = *input.RetryCount + } + if input.HitRetentionDays != nil { + cfg.HitRetentionDays = *input.HitRetentionDays + } + if input.NonHitRetentionDays != nil { + cfg.NonHitRetentionDays = *input.NonHitRetentionDays + } + if input.PreHashCheckEnabled != nil { + cfg.PreHashCheckEnabled = *input.PreHashCheckEnabled + } + if input.AllGroups != nil { + cfg.AllGroups = *input.AllGroups + } + if input.GroupIDs != nil { + cfg.GroupIDs = normalizeInt64IDs(*input.GroupIDs) + } + if input.RecordNonHits != nil { + cfg.RecordNonHits = *input.RecordNonHits + } + if input.ClearAPIKey { + cfg.APIKey = "" + cfg.APIKeys = []string{} + } else { + apiKeysMode := normalizeContentModerationAPIKeysMode(input.APIKeysMode) + if input.DeleteAPIKeyHashes != nil && apiKeysMode != contentModerationAPIKeysModeReplace { + cfg.APIKeys = deleteModerationAPIKeysByHash(cfg.apiKeys(), *input.DeleteAPIKeyHashes) + cfg.APIKey = "" + } + if input.APIKeys != nil { + if apiKeysMode == contentModerationAPIKeysModeReplace { + cfg.APIKeys = normalizeModerationAPIKeys(*input.APIKeys) + } else { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.apiKeys(), *input.APIKeys...)) + } + cfg.APIKey = "" + } + if input.APIKey != nil && strings.TrimSpace(*input.APIKey) != "" { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, *input.APIKey)) + cfg.APIKey = "" + } + } + if err := s.validateConfig(ctx, cfg); err != nil { + return nil, err + } + cfg.normalize() + raw, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal content moderation config: %w", err) + } + if err := s.settingRepo.Set(ctx, SettingKeyContentModerationConfig, string(raw)); err != nil { + return nil, fmt.Errorf("save content moderation config: %w", err) + } + return s.configView(cfg), nil +} + +func (s *ContentModerationService) TestAPIKeys(ctx context.Context, input TestContentModerationAPIKeysInput) (*TestContentModerationAPIKeysResult, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + keys := normalizeModerationAPIKeys(input.APIKeys) + configured := false + if len(keys) == 0 { + keys = cfg.apiKeys() + configured = true + } + if strings.TrimSpace(input.BaseURL) != "" { + cfg.BaseURL = input.BaseURL + } + if strings.TrimSpace(input.Model) != "" { + cfg.Model = input.Model + } + if input.TimeoutMS > 0 { + cfg.TimeoutMS = input.TimeoutMS + } + cfg.normalize() + testInput, imageCount, err := buildModerationTestInput(input.Prompt, input.Images) + if err != nil { + return nil, err + } + auditOnly := contentModerationTestHasAuditInput(input.Prompt, input.Images) + if configured && auditOnly { + key, ok := s.nextUsableAPIKey(cfg) + if !ok { + return &TestContentModerationAPIKeysResult{ + Items: s.apiKeyStatuses(keys), + ImageCount: imageCount, + }, nil + } + keys = []string{key} + } + if len(keys) == 0 { + return &TestContentModerationAPIKeysResult{Items: []ContentModerationAPIKeyStatus{}, ImageCount: imageCount}, nil + } + items := make([]ContentModerationAPIKeyStatus, 0, len(keys)) + var auditResult *ContentModerationTestAuditResult + for idx, key := range keys { + start := time.Now() + httpStatus := 0 + result, err := s.callModerationOnceWithInput(ctx, cfg, key, testInput, &httpStatus) + latency := int(time.Since(start).Milliseconds()) + keyHash := moderationAPIKeyHash(key) + if err != nil { + s.markAPIKeyError(key, err.Error(), latency, httpStatus) + } else { + s.markAPIKeySuccess(key, latency, httpStatus) + if auditResult == nil { + auditResult = buildContentModerationTestAuditResult(result, cfg.Thresholds) + } + } + status := s.apiKeyStatusForHash(idx, keyHash, maskSecretTail(key), configured) + status.LastTested = true + items = append(items, status) + } + return &TestContentModerationAPIKeysResult{Items: items, AuditResult: auditResult, ImageCount: imageCount}, nil +} + +func (s *ContentModerationService) Check(ctx context.Context, input ContentModerationCheckInput) (*ContentModerationDecision, error) { + allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + if s == nil || s.settingRepo == nil || s.repo == nil { + slog.Info("content_moderation.skip_unavailable", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if !s.isRiskControlEnabled(ctx) { + slog.Info("content_moderation.skip_feature_disabled", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + cfg, err := s.loadConfig(ctx) + if err != nil { + slog.Warn("content_moderation.skip_config_load_failed", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "error", err) + return allow, nil + } + inScope := cfg.includesGroup(input.GroupID) + slog.Info("content_moderation.config_loaded", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "provider", input.Provider, + "protocol", input.Protocol, + "model", input.Model, + "enabled", cfg.Enabled, + "mode", cfg.Mode, + "all_groups", cfg.AllGroups, + "configured_group_ids", cfg.GroupIDs, + "in_scope", inScope, + "sample_rate", cfg.SampleRate, + "api_key_count", len(cfg.apiKeys()), + "pre_hash_check_enabled", cfg.PreHashCheckEnabled, + "record_non_hits", cfg.RecordNonHits) + if !cfg.Enabled { + slog.Info("content_moderation.skip_config_disabled", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if cfg.Mode == ContentModerationModeOff { + slog.Info("content_moderation.skip_mode_off", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if !inScope { + slog.Info("content_moderation.skip_group_out_of_scope", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "all_groups", cfg.AllGroups, + "configured_group_ids", cfg.GroupIDs) + return allow, nil + } + content := ExtractContentModerationInput(input.Protocol, input.Body) + if content.IsEmpty() { + slog.Info("content_moderation.skip_empty_input", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "body_bytes", len(input.Body)) + return allow, nil + } + content.Normalize() + slog.Info("content_moderation.input_extracted", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "text_runes", len([]rune(content.Text)), + "image_count", len(content.Images)) + hashText := content.Hash() + if cfg.PreHashCheckEnabled && s.hashCache != nil { + matched, err := s.hashCache.HasFlaggedInputHash(ctx, hashText) + if err != nil { + slog.Warn("content_moderation.hash_check_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) + } + if matched { + slog.Info("content_moderation.hash_block", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "input_hash", hashText) + message := cfg.BlockMessage + if message != "" { + message = fmt.Sprintf("%s(hash: %s)", message, hashText) + } + return &ContentModerationDecision{ + Allowed: false, + Blocked: true, + Flagged: true, + Message: message, + StatusCode: cfg.BlockStatus, + InputHash: hashText, + Action: ContentModerationActionHashBlock, + }, nil + } + } + if !cfg.shouldSample(hashText) { + slog.Info("content_moderation.skip_sample_rate", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "sample_rate", cfg.SampleRate) + return allow, nil + } + if len(cfg.apiKeys()) == 0 { + slog.Warn("content_moderation.skip_no_audit_api_keys", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if cfg.Mode == ContentModerationModeObserve { + slog.Info("content_moderation.enqueue_observe", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "queue_len", len(s.asyncQueue)) + s.enqueueAsync(input, cfg, content, hashText) + return allow, nil + } + + return s.checkSync(ctx, input, cfg, content, hashText, nil, true), nil +} + +func (s *ContentModerationService) checkSync(ctx context.Context, input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string, queueDelay *int, allowBlock bool) *ContentModerationDecision { + allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + start := time.Now() + result, err := s.callModeration(ctx, cfg, content.ModerationInput()) + latency := int(time.Since(start).Milliseconds()) + if err != nil { + slog.Warn("content_moderation.audit_api_failed", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "mode", cfg.Mode, + "allow_block", allowBlock, + "queue_delay_ms", queueDelay, + "latency_ms", latency, + "error", err) + if queueDelay != nil { + s.asyncErrors.Add(1) + } + if cfg.RecordNonHits { + log := s.buildLog(input, cfg, ContentModerationActionError, false, "", 0, nil, content.ExcerptText(), &latency, queueDelay, err.Error()) + _ = s.repo.CreateLog(ctx, log) + } + return allow + } + + flagged, highestCategory, highestScore := evaluateModerationScores(result.CategoryScores, cfg.Thresholds) + action := ContentModerationActionAllow + blocked := false + if allowBlock && flagged && cfg.Mode == ContentModerationModePreBlock { + action = ContentModerationActionBlock + blocked = true + } + slog.Info("content_moderation.audit_result", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "mode", cfg.Mode, + "allow_block", allowBlock, + "flagged", flagged, + "blocked", blocked, + "action", action, + "highest_category", highestCategory, + "highest_score", highestScore, + "latency_ms", latency, + "queue_delay_ms", queueDelay) + if flagged || cfg.RecordNonHits { + log := s.buildLog(input, cfg, action, flagged, highestCategory, highestScore, result.CategoryScores, content.ExcerptText(), &latency, queueDelay, "") + if flagged && s.hashCache != nil { + if err := s.hashCache.RecordFlaggedInputHash(ctx, hashText); err != nil { + slog.Warn("content_moderation.record_hash_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) + } + } + s.applyFlaggedSideEffects(ctx, cfg, log) + _ = s.repo.CreateLog(ctx, log) + } + if blocked { + return &ContentModerationDecision{ + Allowed: false, + Blocked: true, + Flagged: true, + Message: cfg.BlockMessage, + StatusCode: cfg.BlockStatus, + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: result.CategoryScores, + Action: action, + } + } + return &ContentModerationDecision{ + Allowed: true, + Flagged: flagged, + Message: "", + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: result.CategoryScores, + Action: action, + } +} + +func (s *ContentModerationService) enqueueAsync(input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string) { + if s == nil || s.asyncQueue == nil { + return + } + queueSize := defaultContentModerationQueueSize + if cfg != nil && cfg.QueueSize > 0 { + queueSize = cfg.QueueSize + } + if len(s.asyncQueue) >= queueSize { + slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint, "queue_size", queueSize) + s.asyncDropped.Add(1) + return + } + task := contentModerationTask{ + input: input, + content: content, + inputHash: hashText, + enqueuedAt: time.Now(), + } + select { + case s.asyncQueue <- task: + s.asyncEnqueued.Add(1) + default: + slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint) + s.asyncDropped.Add(1) + } +} + +func (s *ContentModerationService) worker(id int) { + for { + ctx, cancel := context.WithTimeout(context.Background(), maxContentModerationTimeoutMS*time.Millisecond+10*time.Second) + cfg, err := s.loadConfig(ctx) + if err != nil || !cfg.Enabled || cfg.Mode == ContentModerationModeOff || len(cfg.apiKeys()) == 0 || id >= cfg.WorkerCount { + cancel() + time.Sleep(time.Second) + continue + } + task, ok := s.dequeueAsyncTask(ctx, time.Second) + if !ok { + cancel() + continue + } + func() { + defer cancel() + defer func() { + if r := recover(); r != nil { + slog.Error("content_moderation.worker_panic", "worker_id", id, "recover", r) + } + }() + if !cfg.includesGroup(task.input.GroupID) { + return + } + s.asyncActive.Add(1) + defer s.asyncActive.Add(-1) + queueDelay := int(time.Since(task.enqueuedAt).Milliseconds()) + _ = s.checkSync(ctx, task.input, cfg, task.content, task.inputHash, &queueDelay, false) + s.asyncProcessed.Add(1) + }() + } +} + +func (s *ContentModerationService) dequeueAsyncTask(ctx context.Context, idleWait time.Duration) (contentModerationTask, bool) { + var zero contentModerationTask + if s == nil || s.asyncQueue == nil { + return zero, false + } + if idleWait <= 0 { + idleWait = time.Second + } + timer := time.NewTimer(idleWait) + defer timer.Stop() + select { + case task, ok := <-s.asyncQueue: + return task, ok + case <-ctx.Done(): + return zero, false + case <-timer.C: + return zero, false + } +} + +func (s *ContentModerationService) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) { + if filter.Pagination.Page <= 0 { + filter.Pagination.Page = 1 + } + if filter.Pagination.PageSize <= 0 { + filter.Pagination.PageSize = 20 + } + if filter.Pagination.PageSize > 100 { + filter.Pagination.PageSize = 100 + } + if filter.Pagination.SortOrder == "" { + filter.Pagination.SortOrder = pagination.SortOrderDesc + } + return s.repo.ListLogs(ctx, filter) +} + +func (s *ContentModerationService) UnbanUser(ctx context.Context, userID int64) (*ContentModerationUnbanUserResult, error) { + if s == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_USER_REPOSITORY_UNAVAILABLE", "用户仓储不可用") + } + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER_ID", "用户 ID 无效") + } + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, infraerrors.NotFound("USER_NOT_FOUND", "用户不存在") + } + return nil, fmt.Errorf("get content moderation unban user: %w", err) + } + if user.Status != StatusActive { + user.Status = StatusActive + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, fmt.Errorf("update content moderation unban user: %w", err) + } + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return &ContentModerationUnbanUserResult{ + UserID: userID, + Status: StatusActive, + }, nil +} + +func (s *ContentModerationService) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (*ContentModerationDeleteHashResult, error) { + inputHash = normalizeContentModerationHash(inputHash) + if inputHash == "" { + return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_HASH", "风险输入哈希无效") + } + if s == nil || s.hashCache == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用") + } + deleted, err := s.hashCache.DeleteFlaggedInputHash(ctx, inputHash) + if err != nil { + return nil, fmt.Errorf("delete content moderation flagged hash: %w", err) + } + return &ContentModerationDeleteHashResult{ + InputHash: inputHash, + Deleted: deleted, + }, nil +} + +func (s *ContentModerationService) ClearFlaggedInputHashes(ctx context.Context) (*ContentModerationClearHashesResult, error) { + if s == nil || s.hashCache == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用") + } + deleted, err := s.hashCache.ClearFlaggedInputHashes(ctx) + if err != nil { + return nil, fmt.Errorf("clear content moderation flagged hashes: %w", err) + } + return &ContentModerationClearHashesResult{Deleted: deleted}, nil +} + +func (s *ContentModerationService) GetStatus(ctx context.Context) (*ContentModerationRuntimeStatus, error) { + if s == nil { + return &ContentModerationRuntimeStatus{}, nil + } + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + riskEnabled := s.isRiskControlEnabled(ctx) + active := int(s.asyncActive.Load()) + if active < 0 { + active = 0 + } + if active > cfg.WorkerCount { + active = cfg.WorkerCount + } + queueLength := 0 + if s.asyncQueue != nil { + queueLength = len(s.asyncQueue) + } + queueUsage := 0.0 + if cfg.QueueSize > 0 { + queueUsage = float64(queueLength) * 100 / float64(cfg.QueueSize) + } + var flaggedHashCount int64 + if s.hashCache != nil { + if n, err := s.hashCache.CountFlaggedInputHashes(ctx); err == nil { + flaggedHashCount = n + } else { + slog.Warn("content_moderation.hash_count_failed", "error", err) + } + } + var lastCleanupAt *time.Time + if unix := s.lastCleanupUnix.Load(); unix > 0 { + t := time.Unix(unix, 0) + lastCleanupAt = &t + } + return &ContentModerationRuntimeStatus{ + Enabled: cfg.Enabled, + RiskControlEnabled: riskEnabled, + Mode: cfg.Mode, + WorkerCount: cfg.WorkerCount, + MaxWorkers: maxContentModerationWorkerCount, + ActiveWorkers: active, + IdleWorkers: cfg.WorkerCount - active, + QueueSize: cfg.QueueSize, + QueueLength: queueLength, + QueueUsagePercent: queueUsage, + Enqueued: s.asyncEnqueued.Load(), + Dropped: s.asyncDropped.Load(), + Processed: s.asyncProcessed.Load(), + Errors: s.asyncErrors.Load(), + APIKeyStatuses: s.apiKeyStatuses(cfg.apiKeys()), + FlaggedHashCount: flaggedHashCount, + LastCleanupAt: lastCleanupAt, + LastCleanupDeletedHit: s.lastCleanupDeletedHit.Load(), + LastCleanupDeletedNonHit: s.lastCleanupDeletedNonHit.Load(), + }, nil +} + +func (s *ContentModerationService) cleanupWorker() { + timer := time.NewTimer(contentModerationCleanupDelay) + defer timer.Stop() + for { + <-timer.C + s.runCleanupOnce() + timer.Reset(contentModerationCleanupInterval) + } +} + +func (s *ContentModerationService) runCleanupOnce() { + if s == nil || s.repo == nil || s.settingRepo == nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), contentModerationCleanupTimeout) + defer cancel() + cfg, err := s.loadConfig(ctx) + if err != nil { + slog.Warn("content_moderation.cleanup_load_config_failed", "error", err) + return + } + now := time.Now() + hitBefore := now.AddDate(0, 0, -cfg.HitRetentionDays) + nonHitBefore := now.AddDate(0, 0, -cfg.NonHitRetentionDays) + result, err := s.repo.CleanupExpiredLogs(ctx, hitBefore, nonHitBefore) + if err != nil { + slog.Warn("content_moderation.cleanup_failed", "error", err) + return + } + if result == nil { + return + } + s.lastCleanupUnix.Store(result.FinishedAt.Unix()) + s.lastCleanupDeletedHit.Store(result.DeletedHit) + s.lastCleanupDeletedNonHit.Store(result.DeletedNonHit) +} + +func (s *ContentModerationService) loadConfig(ctx context.Context) (*ContentModerationConfig, error) { + cfg := defaultContentModerationConfig() + raw, err := s.settingRepo.GetValue(ctx, SettingKeyContentModerationConfig) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + cfg.normalize() + return cfg, nil + } + return nil, fmt.Errorf("get content moderation config: %w", err) + } + if strings.TrimSpace(raw) == "" { + cfg.normalize() + return cfg, nil + } + if err := json.Unmarshal([]byte(raw), cfg); err != nil { + return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不是有效 JSON") + } + cfg.normalize() + return cfg, nil +} + +func (s *ContentModerationService) isRiskControlEnabled(ctx context.Context) bool { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyRiskControlEnabled) + if err != nil { + return false + } + return raw == "true" +} + +func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *ContentModerationConfig) error { + if cfg == nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不能为空") + } + cfg.normalize() + switch cfg.Mode { + case ContentModerationModeOff, ContentModerationModeObserve, ContentModerationModePreBlock: + default: + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODE", "内容审计模式无效") + } + if _, err := url.ParseRequestURI(cfg.BaseURL); err != nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BASE_URL", "OpenAI Base URL 无效") + } + if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间") + } + if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil { + for _, groupID := range cfg.GroupIDs { + if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_GROUP", fmt.Sprintf("审计分组不存在: %d", groupID)) + } + } + } + return nil +} + +func (s *ContentModerationService) callModeration(ctx context.Context, cfg *ContentModerationConfig, input any) (*moderationAPIResult, error) { + attempts := cfg.RetryCount + 1 + if attempts <= 0 { + attempts = 1 + } + if attempts > maxContentModerationRetryCount+1 { + attempts = maxContentModerationRetryCount + 1 + } + var lastErr error + for attempt := 0; attempt < attempts; attempt++ { + key, ok := s.nextUsableAPIKey(cfg) + if !ok { + lastErr = errors.New("no moderation api key available") + break + } + start := time.Now() + httpStatus := 0 + result, err := s.callModerationOnceWithInput(ctx, cfg, key, input, &httpStatus) + latency := int(time.Since(start).Milliseconds()) + if err == nil { + s.markAPIKeySuccess(key, latency, httpStatus) + return result, nil + } + s.markAPIKeyError(key, err.Error(), latency, httpStatus) + lastErr = err + if httpStatus == http.StatusBadRequest { + break + } + if attempt == attempts-1 { + break + } + wait := time.Duration(100*(attempt+1)) * time.Millisecond + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } + return nil, lastErr +} + +func (s *ContentModerationService) callModerationOnceWithInput(ctx context.Context, cfg *ContentModerationConfig, apiKey string, input any, httpStatus *int) (*moderationAPIResult, error) { + base := strings.TrimRight(cfg.BaseURL, "/") + endpoint, err := url.JoinPath(base, "/v1/moderations") + if err != nil { + return nil, err + } + payload := moderationAPIRequest{ + Model: cfg.Model, + Input: input, + } + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + timeout := time.Duration(cfg.TimeoutMS) * time.Millisecond + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewReader(raw)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Content-Type", "application/json") + + client := s.httpClient + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if httpStatus != nil { + *httpStatus = resp.StatusCode + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("moderation api status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var out moderationAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + if len(out.Results) == 0 { + return nil, errors.New("moderation api returned empty results") + } + return &out.Results[0], nil +} + +func (s *ContentModerationService) buildLog(input ContentModerationCheckInput, cfg *ContentModerationConfig, action string, flagged bool, highestCategory string, highestScore float64, scores map[string]float64, text string, latency *int, queueDelay *int, errText string) *ContentModerationLog { + var userID *int64 + if input.UserID > 0 { + userID = &input.UserID + } + var apiKeyID *int64 + if input.APIKeyID > 0 { + apiKeyID = &input.APIKeyID + } + return &ContentModerationLog{ + RequestID: input.RequestID, + UserID: userID, + UserEmail: input.UserEmail, + APIKeyID: apiKeyID, + APIKeyName: input.APIKeyName, + GroupID: cloneInt64Ptr(input.GroupID), + GroupName: input.GroupName, + Endpoint: input.Endpoint, + Provider: input.Provider, + Model: input.Model, + Mode: cfg.Mode, + Action: action, + Flagged: flagged, + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: cloneFloatMap(scores), + ThresholdSnapshot: cloneFloatMap(cfg.Thresholds), + InputExcerpt: trimRunes(redactContentModerationSecrets(text), maxModerationExcerptRunes), + UpstreamLatencyMS: latency, + QueueDelayMS: queueDelay, + Error: errText, + } +} + +func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) { + if s == nil || cfg == nil || log == nil || !log.Flagged || log.UserID == nil || *log.UserID <= 0 { + return + } + count := 1 + if s.repo != nil && cfg.ViolationWindowHours > 0 { + since := time.Now().Add(-time.Duration(cfg.ViolationWindowHours) * time.Hour) + if n, err := s.repo.CountFlaggedByUserSince(ctx, *log.UserID, since); err == nil { + count = n + 1 + } + } + log.ViolationCount = count + autoBanJustApplied := false + if cfg.AutoBanEnabled && cfg.BanThreshold > 0 && count >= cfg.BanThreshold && s.userRepo != nil { + user, err := s.userRepo.GetByID(ctx, *log.UserID) + if err != nil { + slog.Warn("content_moderation.ban_get_user_failed", "user_id", *log.UserID, "error", err) + return + } + if user.Status != StatusDisabled { + user.Status = StatusDisabled + if err := s.userRepo.Update(ctx, user); err != nil { + slog.Warn("content_moderation.ban_update_user_failed", "user_id", *log.UserID, "error", err) + return + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, *log.UserID) + } + autoBanJustApplied = true + } + log.AutoBanned = true + } + + if s.emailService == nil || strings.TrimSpace(log.UserEmail) == "" { + return + } + emailSent := false + if cfg.EmailOnHit { + if err := s.sendViolationEmail(ctx, cfg, log); err != nil { + slog.Warn("content_moderation.email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err) + } else { + emailSent = true + } + } + if autoBanJustApplied { + if err := s.sendAccountDisabledEmail(ctx, cfg, log); err != nil { + slog.Warn("content_moderation.ban_email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err) + } else { + emailSent = true + } + } + log.EmailSent = emailSent +} + +func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { + siteName := s.siteName(ctx) + subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName)) + body := buildContentModerationViolationEmailBody(siteName, log, cfg) + return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) +} + +func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { + siteName := s.siteName(ctx) + subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName)) + body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg) + return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) +} + +func (s *ContentModerationService) siteName(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return "Sub2API" + } + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || strings.TrimSpace(name) == "" { + return "Sub2API" + } + return strings.TrimSpace(name) +} + +func defaultContentModerationConfig() *ContentModerationConfig { + return &ContentModerationConfig{ + Enabled: false, + Mode: ContentModerationModePreBlock, + BaseURL: defaultContentModerationBaseURL, + Model: defaultContentModerationModel, + TimeoutMS: defaultContentModerationTimeoutMS, + SampleRate: 100, + AllGroups: true, + GroupIDs: []int64{}, + RecordNonHits: false, + Thresholds: ContentModerationDefaultThresholds(), + WorkerCount: defaultContentModerationWorkerCount, + QueueSize: defaultContentModerationQueueSize, + BlockStatus: defaultContentModerationBlockHTTPStatus, + BlockMessage: defaultContentModerationBlockMessage, + EmailOnHit: true, + AutoBanEnabled: true, + BanThreshold: defaultContentModerationBanThreshold, + ViolationWindowHours: defaultContentModerationViolationWindowHours, + RetryCount: defaultContentModerationRetryCount, + HitRetentionDays: defaultContentModerationHitRetentionDays, + NonHitRetentionDays: defaultContentModerationNonHitRetentionDays, + PreHashCheckEnabled: false, + } +} + +func (cfg *ContentModerationConfig) normalize() { + if cfg.APIKey != "" { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, cfg.APIKey)) + cfg.APIKey = "" + } else { + cfg.APIKeys = normalizeModerationAPIKeys(cfg.APIKeys) + } + if cfg.Mode == "" { + cfg.Mode = ContentModerationModePreBlock + } + if cfg.BaseURL == "" { + cfg.BaseURL = defaultContentModerationBaseURL + } + cfg.BaseURL = strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if cfg.Model == "" { + cfg.Model = defaultContentModerationModel + } + cfg.Model = strings.TrimSpace(cfg.Model) + if cfg.TimeoutMS <= 0 { + cfg.TimeoutMS = defaultContentModerationTimeoutMS + } + if cfg.TimeoutMS > maxContentModerationTimeoutMS { + cfg.TimeoutMS = maxContentModerationTimeoutMS + } + if cfg.SampleRate < 0 { + cfg.SampleRate = 0 + } + if cfg.SampleRate > 100 { + cfg.SampleRate = 100 + } + if cfg.WorkerCount <= 0 { + cfg.WorkerCount = defaultContentModerationWorkerCount + } + if cfg.WorkerCount > maxContentModerationWorkerCount { + cfg.WorkerCount = maxContentModerationWorkerCount + } + if cfg.QueueSize <= 0 { + cfg.QueueSize = defaultContentModerationQueueSize + } + if cfg.QueueSize > maxContentModerationQueueSize { + cfg.QueueSize = maxContentModerationQueueSize + } + if strings.TrimSpace(cfg.BlockMessage) == "" { + cfg.BlockMessage = defaultContentModerationBlockMessage + } + cfg.BlockMessage = strings.TrimSpace(cfg.BlockMessage) + if cfg.BlockStatus <= 0 { + cfg.BlockStatus = defaultContentModerationBlockHTTPStatus + } + if cfg.BanThreshold <= 0 { + cfg.BanThreshold = defaultContentModerationBanThreshold + } + if cfg.ViolationWindowHours <= 0 { + cfg.ViolationWindowHours = defaultContentModerationViolationWindowHours + } + if cfg.RetryCount < 0 { + cfg.RetryCount = 0 + } + if cfg.RetryCount > maxContentModerationRetryCount { + cfg.RetryCount = maxContentModerationRetryCount + } + if cfg.HitRetentionDays <= 0 { + cfg.HitRetentionDays = defaultContentModerationHitRetentionDays + } + if cfg.HitRetentionDays > maxContentModerationRetentionDays { + cfg.HitRetentionDays = maxContentModerationRetentionDays + } + if cfg.NonHitRetentionDays <= 0 { + cfg.NonHitRetentionDays = defaultContentModerationNonHitRetentionDays + } + if cfg.NonHitRetentionDays > maxContentModerationNonHitRetentionDays { + cfg.NonHitRetentionDays = maxContentModerationNonHitRetentionDays + } + cfg.GroupIDs = normalizeInt64IDs(cfg.GroupIDs) + cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds) +} + +func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool { + if cfg.AllGroups { + return true + } + if groupID == nil { + return false + } + for _, id := range cfg.GroupIDs { + if id == *groupID { + return true + } + } + return false +} + +func contentModerationLogGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func (cfg *ContentModerationConfig) shouldSample(hashText string) bool { + if cfg.SampleRate >= 100 { + return true + } + if cfg.SampleRate <= 0 { + return false + } + raw, err := hex.DecodeString(hashText) + if err != nil || len(raw) < 2 { + return true + } + return int(binary.BigEndian.Uint16(raw[:2])%100) < cfg.SampleRate +} + +func (cfg *ContentModerationConfig) apiKeys() []string { + if cfg == nil { + return nil + } + return normalizeModerationAPIKeys(cfg.APIKeys) +} + +func (s *ContentModerationService) nextUsableAPIKey(cfg *ContentModerationConfig) (string, bool) { + keys := cfg.apiKeys() + if len(keys) == 0 { + return "", false + } + now := time.Now() + for i := 0; i < len(keys); i++ { + idx := int(s.apiKeyCursor.Add(1)-1) % len(keys) + key := keys[idx] + if !s.isAPIKeyFrozen(key, now) { + return key, true + } + } + return "", false +} + +func (s *ContentModerationService) isAPIKeyFrozen(key string, now time.Time) bool { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return false + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + return state != nil && state.FrozenUntil.After(now) +} + +func (s *ContentModerationService) markAPIKeySuccess(key string, latencyMS int, httpStatus int) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + state.FailureCount = 0 + state.SuccessCount++ + state.LastError = "" + state.LastCheckedAt = time.Now() + state.FrozenUntil = time.Time{} + state.LastLatencyMS = latencyMS + state.LastHTTPStatus = httpStatus + state.LastTested = true +} + +func (s *ContentModerationService) markAPIKeyError(key string, errText string, latencyMS int, httpStatus int) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + if contentModerationFreezeDurationForHTTPStatus(httpStatus) > 0 { + state.FailureCount++ + } + state.LastError = trimRunes(errText, 180) + state.LastCheckedAt = time.Now() + state.LastLatencyMS = latencyMS + state.LastHTTPStatus = httpStatus + state.LastTested = true + if freezeDuration := contentModerationFreezeDurationForHTTPStatus(httpStatus); freezeDuration > 0 { + state.FrozenUntil = time.Now().Add(freezeDuration) + } +} + +func contentModerationFreezeDurationForHTTPStatus(httpStatus int) time.Duration { + switch httpStatus { + case 0, http.StatusBadRequest: + return 0 + case http.StatusUnauthorized, http.StatusForbidden: + return contentModerationKeyAuthFreezeDuration + case http.StatusTooManyRequests, 529: + return contentModerationKeyRateLimitFreezeDuration + default: + return contentModerationKeyHTTPErrorFreezeDuration + } +} + +func (s *ContentModerationService) ensureAPIKeyHealthLocked(hash string, masked string) *contentModerationKeyHealth { + if s.keyHealth == nil { + s.keyHealth = make(map[string]*contentModerationKeyHealth) + } + state := s.keyHealth[hash] + if state == nil { + state = &contentModerationKeyHealth{Hash: hash} + s.keyHealth[hash] = state + } + if strings.TrimSpace(masked) != "" { + state.Masked = masked + } + return state +} + +func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *ContentModerationConfigView { + keys := cfg.apiKeys() + masks := make([]string, 0, len(keys)) + for _, key := range keys { + masks = append(masks, maskSecretTail(key)) + } + apiKeyMasked := "" + if len(masks) > 0 { + apiKeyMasked = masks[0] + } + return &ContentModerationConfigView{ + Enabled: cfg.Enabled, + Mode: cfg.Mode, + BaseURL: cfg.BaseURL, + Model: cfg.Model, + APIKeyConfigured: len(keys) > 0, + APIKeyMasked: apiKeyMasked, + APIKeyCount: len(keys), + APIKeyMasks: masks, + APIKeyStatuses: s.apiKeyStatuses(keys), + TimeoutMS: cfg.TimeoutMS, + SampleRate: cfg.SampleRate, + AllGroups: cfg.AllGroups, + GroupIDs: append([]int64(nil), cfg.GroupIDs...), + RecordNonHits: cfg.RecordNonHits, + WorkerCount: cfg.WorkerCount, + QueueSize: cfg.QueueSize, + BlockStatus: cfg.BlockStatus, + BlockMessage: cfg.BlockMessage, + EmailOnHit: cfg.EmailOnHit, + AutoBanEnabled: cfg.AutoBanEnabled, + BanThreshold: cfg.BanThreshold, + ViolationWindowHours: cfg.ViolationWindowHours, + RetryCount: cfg.RetryCount, + HitRetentionDays: cfg.HitRetentionDays, + NonHitRetentionDays: cfg.NonHitRetentionDays, + PreHashCheckEnabled: cfg.PreHashCheckEnabled, + } +} + +func (s *ContentModerationService) apiKeyStatuses(keys []string) []ContentModerationAPIKeyStatus { + out := make([]ContentModerationAPIKeyStatus, 0, len(keys)) + for idx, key := range keys { + out = append(out, s.apiKeyStatusForHash(idx, moderationAPIKeyHash(key), maskSecretTail(key), true)) + } + return out +} + +func (s *ContentModerationService) apiKeyStatusForHash(index int, hash string, masked string, configured bool) ContentModerationAPIKeyStatus { + status := ContentModerationAPIKeyStatus{ + Index: index, + KeyHash: hash, + Masked: masked, + Status: "unknown", + Configured: configured, + } + if hash == "" || s == nil { + return status + } + now := time.Now() + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + if state == nil { + return status + } + status.FailureCount = state.FailureCount + status.SuccessCount = state.SuccessCount + status.LastError = state.LastError + status.LastLatencyMS = state.LastLatencyMS + status.LastHTTPStatus = state.LastHTTPStatus + status.LastTested = state.LastTested + if !state.LastCheckedAt.IsZero() { + t := state.LastCheckedAt + status.LastCheckedAt = &t + } + if state.FrozenUntil.After(now) { + t := state.FrozenUntil + status.FrozenUntil = &t + status.Status = "frozen" + return status + } + if state.LastError != "" { + status.Status = "error" + return status + } + if state.SuccessCount > 0 || state.LastTested { + status.Status = "ok" + } + return status +} + +func moderationAPIKeyHash(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func buildModerationTestInput(prompt string, images []string) (any, int, error) { + prompt = trimRunes(normalizeContentModerationText(prompt), maxModerationInputRunes) + normalizedImages := make([]string, 0, len(images)) + for _, image := range images { + image = strings.TrimSpace(image) + if image == "" { + continue + } + if len(normalizedImages) >= maxContentModerationTestImages { + return nil, 0, infraerrors.BadRequest("TOO_MANY_MODERATION_TEST_IMAGES", fmt.Sprintf("最多上传 %d 张测试图片", maxContentModerationTestImages)) + } + if err := validateModerationTestImageDataURL(image); err != nil { + return nil, 0, err + } + normalizedImages = append(normalizedImages, image) + } + if prompt == "" && len(normalizedImages) == 0 { + return "hello", 0, nil + } + if len(normalizedImages) == 0 { + return prompt, 0, nil + } + parts := make([]moderationAPIInputPart, 0, len(normalizedImages)+1) + if prompt != "" { + parts = append(parts, moderationAPIInputPart{Type: "text", Text: prompt}) + } + for _, image := range normalizedImages { + parts = append(parts, moderationAPIInputPart{ + Type: "image_url", + ImageURL: &moderationAPIImageURLRef{URL: image}, + }) + } + return parts, len(normalizedImages), nil +} + +func contentModerationTestHasAuditInput(prompt string, images []string) bool { + if normalizeContentModerationText(prompt) != "" { + return true + } + for _, image := range images { + if strings.TrimSpace(image) != "" { + return true + } + } + return false +} + +func validateModerationTestImageDataURL(value string) error { + if len(value) > maxContentModerationTestImageDataURLBytes { + return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB") + } + if !strings.HasPrefix(value, "data:image/") { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 data:image/* base64") + } + parts := strings.SplitN(value, ",", 2) + if len(parts) != 2 || !strings.Contains(parts[0], ";base64") { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 base64 data URL") + } + raw, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片 base64 无效") + } + if len(raw) > maxContentModerationTestImageBytes { + return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB") + } + return nil +} + +func buildContentModerationTestAuditResult(result *moderationAPIResult, thresholds map[string]float64) *ContentModerationTestAuditResult { + if result == nil { + return nil + } + scores := make(map[string]float64, len(result.CategoryScores)) + for category, score := range result.CategoryScores { + scores[category] = score + } + thresholdSnapshot := mergeContentModerationThresholds(ContentModerationDefaultThresholds(), thresholds) + flagged, highestCategory, highestScore := evaluateModerationScores(scores, thresholdSnapshot) + compositeScore := highestScore + return &ContentModerationTestAuditResult{ + Flagged: flagged, + HighestCategory: highestCategory, + HighestScore: highestScore, + CompositeScore: compositeScore, + CategoryScores: scores, + Thresholds: thresholdSnapshot, + } +} + +type moderationAPIRequest struct { + Model string `json:"model"` + Input any `json:"input"` +} + +type moderationAPIInputPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *moderationAPIImageURLRef `json:"image_url,omitempty"` +} + +type moderationAPIImageURLRef struct { + URL string `json:"url"` +} + +type moderationAPIResponse struct { + Results []moderationAPIResult `json:"results"` +} + +type moderationAPIResult struct { + Flagged bool `json:"flagged"` + CategoryScores map[string]float64 `json:"category_scores"` +} + +func evaluateModerationScores(scores map[string]float64, thresholds map[string]float64) (bool, string, float64) { + flagged := false + highestCategory := "" + highestScore := 0.0 + for _, category := range contentModerationCategoryOrder { + score := scores[category] + if score > highestScore || highestCategory == "" { + highestScore = score + highestCategory = category + } + if score >= thresholds[category] { + flagged = true + } + } + for category, score := range scores { + if score > highestScore || highestCategory == "" { + highestScore = score + highestCategory = category + } + } + return flagged, highestCategory, highestScore +} + +func mergeContentModerationThresholds(base map[string]float64, override map[string]float64) map[string]float64 { + out := cloneFloatMap(base) + if out == nil { + out = map[string]float64{} + } + for _, category := range contentModerationCategoryOrder { + if v, ok := override[category]; ok { + if v < 0 { + v = 0 + } + if v > 1 { + v = 1 + } + out[category] = v + } + } + return out +} + +func normalizeInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return []int64{} + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + return out +} + +func normalizeModerationAPIKeys(keys []string) []string { + if len(keys) == 0 { + return []string{} + } + seen := make(map[string]struct{}, len(keys)) + out := make([]string, 0, len(keys)) + for _, key := range keys { + key = strings.TrimSpace(key) + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, key) + } + return out +} + +func deleteModerationAPIKeysByHash(keys []string, hashes []string) []string { + keys = normalizeModerationAPIKeys(keys) + deleteHashes := make(map[string]struct{}, len(hashes)) + for _, hash := range hashes { + hash = normalizeContentModerationHash(hash) + if hash != "" { + deleteHashes[hash] = struct{}{} + } + } + if len(deleteHashes) == 0 { + return keys + } + out := make([]string, 0, len(keys)) + for _, key := range keys { + if _, ok := deleteHashes[moderationAPIKeyHash(key)]; ok { + continue + } + out = append(out, key) + } + return out +} + +func normalizeContentModerationAPIKeysMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case contentModerationAPIKeysModeReplace: + return contentModerationAPIKeysModeReplace + default: + return contentModerationAPIKeysModeAppend + } +} + +func normalizeContentModerationHash(inputHash string) string { + inputHash = strings.ToLower(strings.TrimSpace(inputHash)) + if len(inputHash) != sha256.Size*2 { + return "" + } + if _, err := hex.DecodeString(inputHash); err != nil { + return "" + } + return inputHash +} + +func cloneFloatMap(in map[string]float64) map[string]float64 { + if in == nil { + return map[string]float64{} + } + out := make(map[string]float64, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneInt64Ptr(in *int64) *int64 { + if in == nil { + return nil + } + v := *in + return &v +} + +func trimRunes(text string, max int) string { + if max <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= max { + return text + } + return string(runes[:max]) +} + +func maskSecretTail(secret string) string { + secret = strings.TrimSpace(secret) + if secret == "" { + return "" + } + if len(secret) <= 4 { + return "****" + } + return strings.Repeat("*", 8) + secret[len(secret)-4:] +} diff --git a/backend/internal/service/content_moderation_email.go b/backend/internal/service/content_moderation_email.go new file mode 100644 index 00000000..e462ff88 --- /dev/null +++ b/backend/internal/service/content_moderation_email.go @@ -0,0 +1,117 @@ +package service + +import ( + "fmt" + "html" + "strings" + "time" +) + +func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string { + if log == nil { + return "" + } + userName := strings.TrimSpace(log.UserEmail) + if userName == "" && log.UserID != nil { + userName = fmt.Sprintf("UID %d", *log.UserID) + } + threshold := cfg.BanThreshold + if threshold <= 0 { + threshold = defaultContentModerationBanThreshold + } + statusBlock := "" + if log.AutoBanned { + statusBlock = `
账户当前处于封禁状态,所有 API 请求将被拒绝
` + } + return fmt.Sprintf(` + + +
+
+
+
Risk Control / 风控提醒
+

账户触发内容审计规则

+

尊敬的用户 %s,您的 API 请求在内容审计中触发平台风控策略。详情如下。

+
+

触发详情

+ + + + + + +
触发时间%s
触发来源内容审核
所属分组%s
命中类别%s / %.3f
累计触发次数%d 次(阈值 %d)
+
+ %s +

此邮件由 %s 自动发送,请勿回复。

+
+
+ +`, + html.EscapeString(userName), + html.EscapeString(time.Now().Format("2006-01-02 15:04:05")), + html.EscapeString(defaultContentModerationString(log.GroupName, "-")), + html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")), + log.HighestScore, + log.ViolationCount, + threshold, + statusBlock, + html.EscapeString(siteName), + ) +} + +func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string { + if log == nil { + return "" + } + userName := strings.TrimSpace(log.UserEmail) + if userName == "" && log.UserID != nil { + userName = fmt.Sprintf("UID %d", *log.UserID) + } + threshold := cfg.BanThreshold + if threshold <= 0 { + threshold = defaultContentModerationBanThreshold + } + return fmt.Sprintf(` + + +
+
+
+
Risk Control / 账户封禁
+

账户已被自动禁用

+

尊敬的用户 %s,您的账户在计数周期内多次触发平台风控策略,系统已自动禁用该账户。详情如下。

+
+

封禁详情

+ + + + + + +
封禁时间%s
触发来源内容审核
所属分组%s
命中类别%s / %.3f
累计触发次数%d 次(阈值 %d)
+
+
账户当前处于封禁状态,所有 API 请求将被拒绝
+

如需申诉或恢复账号,请联系平台管理员处理。

+

此邮件由 %s 自动发送,请勿回复。

+
+
+ +`, + html.EscapeString(userName), + html.EscapeString(time.Now().Format("2006-01-02 15:04:05")), + html.EscapeString(defaultContentModerationString(log.GroupName, "-")), + html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")), + log.HighestScore, + log.ViolationCount, + threshold, + html.EscapeString(siteName), + ) +} + +func defaultContentModerationString(value string, fallback string) string { + if strings.TrimSpace(value) == "" { + return fallback + } + return strings.TrimSpace(value) +} diff --git a/backend/internal/service/content_moderation_input.go b/backend/internal/service/content_moderation_input.go new file mode 100644 index 00000000..67df397d --- /dev/null +++ b/backend/internal/service/content_moderation_input.go @@ -0,0 +1,320 @@ +package service + +import ( + "crypto/rand" + "fmt" + "math/big" + "strings" + + "github.com/tidwall/gjson" +) + +func ExtractContentModerationText(protocol string, body []byte) string { + return ExtractContentModerationInput(protocol, body).Text +} + +func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput { + if len(body) == 0 || !gjson.ValidBytes(body) { + return ContentModerationInput{} + } + var parts []string + var images []string + switch protocol { + case ContentModerationProtocolAnthropicMessages: + collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images) + case ContentModerationProtocolOpenAIChat: + collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images) + case ContentModerationProtocolOpenAIResponses: + collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images) + case ContentModerationProtocolGemini: + collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images) + case ContentModerationProtocolOpenAIImages: + addModerationText(&parts, gjson.GetBytes(body, "prompt").String()) + collectContentValue(gjson.GetBytes(body, "images"), &parts, &images) + default: + collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images) + collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images) + collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images) + } + out := ContentModerationInput{ + Text: normalizeContentModerationText(strings.Join(parts, "\n")), + Images: normalizeModerationImages(images), + } + out.Normalize() + return out +} + +func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) { + if !messages.IsArray() { + return + } + var lastParts []string + var lastImages []string + messages.ForEach(func(_, msg gjson.Result) bool { + if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role { + var candidate []string + var candidateImages []string + collectContentValue(msg.Get("content"), &candidate, &candidateImages) + if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { + lastParts = candidate + lastImages = candidateImages + } + } + return true + }) + *parts = append(*parts, lastParts...) + *images = append(*images, lastImages...) +} + +func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) { + if !messages.IsArray() { + return + } + var lastParts []string + var lastImages []string + messages.ForEach(func(_, msg gjson.Result) bool { + if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" { + var candidate []string + var candidateImages []string + collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages) + if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { + lastParts = candidate + lastImages = candidateImages + } + } + return true + }) + *parts = append(*parts, lastParts...) + *images = append(*images, lastImages...) +} + +func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) { + switch { + case !value.Exists(): + return + case value.Type == gjson.String: + if !isAnthropicSystemReminderText(value.String()) { + addModerationText(parts, value.String()) + } + case value.IsArray(): + value.ForEach(func(_, item gjson.Result) bool { + collectAnthropicUserContentValue(item, parts, images) + return true + }) + case value.IsObject(): + typ := strings.ToLower(strings.TrimSpace(value.Get("type").String())) + switch typ { + case "", "text", "input_text", "message": + if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) { + addModerationText(parts, value.Get("text").String()) + } + if value.Get("content").Exists() { + collectAnthropicUserContentValue(value.Get("content"), parts, images) + } + case "image_url", "input_image", "image": + collectContentValue(value, parts, images) + } + } +} + +func isAnthropicSystemReminderText(text string) bool { + return strings.HasPrefix(strings.TrimSpace(text), "") +} + +func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) { + switch { + case !input.Exists(): + return + case input.Type == gjson.String: + addModerationText(parts, input.String()) + case input.IsArray(): + var last gjson.Result + input.ForEach(func(_, item gjson.Result) bool { + if isResponsesUserTextItem(item) { + last = item + } + return true + }) + if last.Exists() { + collectContentValue(last.Get("content"), parts, images) + if last.Get("type").String() == "input_text" || last.Get("text").Exists() { + collectContentValue(last, parts, images) + } + } + case input.IsObject(): + if isResponsesUserTextItem(input) { + collectContentValue(input.Get("content"), parts, images) + if input.Get("type").String() == "input_text" || input.Get("text").Exists() { + collectContentValue(input, parts, images) + } + } + } +} + +func isResponsesUserTextItem(item gjson.Result) bool { + role := strings.ToLower(strings.TrimSpace(item.Get("role").String())) + if role == "user" { + return responseItemHasModerationText(item) + } + if role != "" { + return false + } + return responseItemHasModerationText(item) +} + +func responseItemHasModerationText(item gjson.Result) bool { + var parts []string + var images []string + collectContentValue(item.Get("content"), &parts, &images) + if item.Get("type").String() == "input_text" || item.Get("text").Exists() { + collectContentValue(item, &parts, &images) + } + return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0 +} + +func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) { + if !contents.IsArray() { + return + } + var lastParts []string + var lastImages []string + contents.ForEach(func(_, content gjson.Result) bool { + role := strings.ToLower(strings.TrimSpace(content.Get("role").String())) + if role == "" || role == "user" { + var candidate []string + var candidateImages []string + if arr := content.Get("parts"); arr.IsArray() { + arr.ForEach(func(_, part gjson.Result) bool { + addModerationText(&candidate, part.Get("text").String()) + addGeminiModerationImage(&candidateImages, part) + return true + }) + } + if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { + lastParts = candidate + lastImages = candidateImages + } + } + return true + }) + *parts = append(*parts, lastParts...) + *images = append(*images, lastImages...) +} + +func collectContentValue(value gjson.Result, parts *[]string, images *[]string) { + switch { + case !value.Exists(): + return + case value.Type == gjson.String: + addModerationText(parts, value.String()) + case value.IsArray(): + value.ForEach(func(_, item gjson.Result) bool { + collectContentValue(item, parts, images) + return true + }) + case value.IsObject(): + typ := strings.ToLower(strings.TrimSpace(value.Get("type").String())) + addModerationImage(images, value.Get("image_url.url").String()) + addModerationImage(images, value.Get("image_url").String()) + addModerationImage(images, value.Get("url").String()) + addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String()) + addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String()) + addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String()) + addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String()) + addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String()) + addModerationImage(images, value.Get("source.data").String()) + addModerationImage(images, value.Get("data").String()) + addModerationImage(images, value.Get("base64").String()) + switch typ { + case "", "text", "input_text", "message": + if value.Get("text").Exists() { + addModerationText(parts, value.Get("text").String()) + } + if value.Get("content").Exists() { + collectContentValue(value.Get("content"), parts, images) + } + case "image_url", "input_image", "image": + } + } +} + +func addGeminiModerationImage(images *[]string, part gjson.Result) { + if inlineData := part.Get("inline_data"); inlineData.IsObject() { + mimeType := strings.TrimSpace(inlineData.Get("mime_type").String()) + data := strings.TrimSpace(inlineData.Get("data").String()) + if mimeType != "" && data != "" { + addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data)) + } + } + if inlineData := part.Get("inlineData"); inlineData.IsObject() { + mimeType := strings.TrimSpace(inlineData.Get("mimeType").String()) + data := strings.TrimSpace(inlineData.Get("data").String()) + if mimeType != "" && data != "" { + addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data)) + } + } + addModerationImage(images, part.Get("file_data.file_uri").String()) + addModerationImage(images, part.Get("fileData.fileUri").String()) +} + +func addModerationImageData(images *[]string, mimeType string, data string) { + mimeType = strings.TrimSpace(mimeType) + data = strings.TrimSpace(data) + if mimeType == "" || data == "" { + return + } + addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data)) +} + +func addModerationImage(images *[]string, image string) { + image = strings.TrimSpace(image) + if image == "" { + return + } + if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") { + *images = append(*images, image) + } +} + +func normalizeModerationImages(images []string) []string { + out := make([]string, 0, len(images)) + seen := make(map[string]struct{}, len(images)) + for _, image := range images { + image = strings.TrimSpace(image) + if image == "" { + continue + } + if _, ok := seen[image]; ok { + continue + } + seen[image] = struct{}{} + out = append(out, image) + } + return out +} + +func limitContentModerationImages(images []string) []string { + if len(images) <= maxContentModerationInputImages { + return images + } + idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(images)))) + if err != nil { + return images[:maxContentModerationInputImages] + } + return []string{images[int(idx.Int64())]} +} + +func addModerationText(parts *[]string, text string) { + text = strings.TrimSpace(text) + if text == "" { + return + } + if strings.Contains(text, "") { + return + } + *parts = append(*parts, text) +} + +func normalizeContentModerationText(text string) string { + return strings.Join(strings.Fields(strings.TrimSpace(text)), " ") +} diff --git a/backend/internal/service/content_moderation_redact.go b/backend/internal/service/content_moderation_redact.go new file mode 100644 index 00000000..473c8178 --- /dev/null +++ b/backend/internal/service/content_moderation_redact.go @@ -0,0 +1,37 @@ +package service + +import ( + "regexp" + "strings" +) + +var contentModerationSecretPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)\bhttps?://[^\s"'<>,。;、]+`), + regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`), + regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`), + regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`), + regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`), + regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`), + regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`), + regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`), + regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`), +} + +func redactContentModerationSecrets(text string) string { + text = strings.TrimSpace(text) + if text == "" { + return "" + } + out := text + for idx, pattern := range contentModerationSecretPatterns { + switch idx { + case 1: + out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`) + case 2: + out = pattern.ReplaceAllString(out, `${1}[已脱敏]`) + default: + out = pattern.ReplaceAllString(out, `[已脱敏]`) + } + } + return out +} diff --git a/backend/internal/service/content_moderation_test.go b/backend/internal/service/content_moderation_test.go new file mode 100644 index 00000000..cef5127e --- /dev/null +++ b/backend/internal/service/content_moderation_test.go @@ -0,0 +1,1006 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type contentModerationTestSettingRepo struct { + values map[string]string +} + +func (r *contentModerationTestSettingRepo) Get(ctx context.Context, key string) (*Setting, error) { + if value, ok := r.values[key]; ok { + return &Setting{Key: key, Value: value}, nil + } + return nil, ErrSettingNotFound +} + +func (r *contentModerationTestSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + if value, ok := r.values[key]; ok { + return value, nil + } + return "", ErrSettingNotFound +} + +func (r *contentModerationTestSettingRepo) Set(ctx context.Context, key, value string) error { + if r.values == nil { + r.values = map[string]string{} + } + r.values[key] = value + return nil +} + +func (r *contentModerationTestSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := map[string]string{} + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *contentModerationTestSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + if r.values == nil { + r.values = map[string]string{} + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *contentModerationTestSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key string) error { + delete(r.values, key) + return nil +} + +type contentModerationTestRepo struct { + logs []ContentModerationLog +} + +func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error { + if log != nil { + r.logs = append(r.logs, *log) + } + return nil +} + +func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + return 0, nil +} + +func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) { + return &ContentModerationCleanupResult{}, nil +} + +type contentModerationTestHashCache struct { + hashes map[string]struct{} + recorded []string + checked []string + deleted []string + hasResult bool + hasResultUsed bool +} + +type contentModerationTestUserRepo struct { + user *User + updated []User +} + +func (r *contentModerationTestUserRepo) Create(ctx context.Context, user *User) error { + panic("unexpected Create call") +} + +func (r *contentModerationTestUserRepo) GetByID(ctx context.Context, id int64) (*User, error) { + if r.user == nil { + return nil, ErrUserNotFound + } + clone := *r.user + return &clone, nil +} + +func (r *contentModerationTestUserRepo) GetByEmail(ctx context.Context, email string) (*User, error) { + panic("unexpected GetByEmail call") +} + +func (r *contentModerationTestUserRepo) GetFirstAdmin(ctx context.Context) (*User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *contentModerationTestUserRepo) Update(ctx context.Context, user *User) error { + if user == nil { + return nil + } + clone := *user + r.updated = append(r.updated, clone) + r.user = &clone + return nil +} + +func (r *contentModerationTestUserRepo) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (r *contentModerationTestUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (r *contentModerationTestUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (r *contentModerationTestUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + +func (r *contentModerationTestUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *contentModerationTestUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserIDs call") +} + +func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserID call") +} + +func (r *contentModerationTestUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + panic("unexpected UpdateUserLastActiveAt call") +} + +func (r *contentModerationTestUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *contentModerationTestUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected DeductBalance call") +} + +func (r *contentModerationTestUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *contentModerationTestUserRepo) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) { + panic("unexpected BatchSetConcurrency call") +} + +func (r *contentModerationTestUserRepo) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) { + panic("unexpected BatchAddConcurrency call") +} + +func (r *contentModerationTestUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + panic("unexpected ExistsByEmail call") +} + +func (r *contentModerationTestUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *contentModerationTestUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *contentModerationTestUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *contentModerationTestUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + +func (r *contentModerationTestUserRepo) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error { + panic("unexpected UnbindUserAuthProvider call") +} + +func (r *contentModerationTestUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (r *contentModerationTestUserRepo) EnableTotp(ctx context.Context, userID int64) error { + panic("unexpected EnableTotp call") +} + +func (r *contentModerationTestUserRepo) DisableTotp(ctx context.Context, userID int64) error { + panic("unexpected DisableTotp call") +} + +type contentModerationTestAuthCacheInvalidator struct { + userIDs []int64 +} + +func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByKey(ctx context.Context, key string) { +} + +func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + i.userIDs = append(i.userIDs, userID) +} + +func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { +} + +func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error { + if c.hashes == nil { + c.hashes = map[string]struct{}{} + } + c.hashes[inputHash] = struct{}{} + c.recorded = append(c.recorded, inputHash) + return nil +} + +func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + c.checked = append(c.checked, inputHash) + if c.hasResultUsed { + return c.hasResult, nil + } + _, ok := c.hashes[inputHash] + return ok, nil +} + +func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + c.deleted = append(c.deleted, inputHash) + if c.hashes == nil { + return false, nil + } + if _, ok := c.hashes[inputHash]; !ok { + return false, nil + } + delete(c.hashes, inputHash) + return true, nil +} + +func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) { + deleted := int64(len(c.hashes)) + c.hashes = map[string]struct{}{} + return deleted, nil +} + +func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) { + return int64(len(c.hashes)), nil +} + +func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) { + svc := &ContentModerationService{} + cfg := defaultContentModerationConfig() + input := ContentModerationCheckInput{ + RequestID: "req-1", + Endpoint: "/v1/chat/completions", + Provider: "openai", + } + + log := svc.buildLog(input, cfg, ContentModerationActionAllow, true, "sexual", 0.8, map[string]float64{"sexual": 0.8}, "hello sk-proj-1234567890abcdef", nil, nil, "") + + require.NotContains(t, log.InputExcerpt, "sk-proj-1234567890abcdef") + require.Contains(t, log.InputExcerpt, "[已脱敏]") +} + +func TestRedactContentModerationSecrets_LongHexAndTokens(t *testing.T) { + input := "你哈市多大事cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554 token=abc123456789xyz Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signaturepart https://example.com/private/path?token=abc123" + + out := redactContentModerationSecrets(input) + + require.NotContains(t, out, "cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554") + require.NotContains(t, out, "abc123456789xyz") + require.NotContains(t, out, "eyJhbGciOiJIUzI1NiJ9") + require.NotContains(t, out, "https://example.com/private/path") + require.Contains(t, out, "[已脱敏]") +} + +func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.NonHitRetentionDays = 30 + + cfg.normalize() + + require.Equal(t, 3, cfg.NonHitRetentionDays) +} + +func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.APIKeys = []string{"sk-old-a", "sk-old-b"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyContentModerationConfig: string(rawCfg), + }} + svc := NewContentModerationService(repo, nil, nil, nil, nil, nil, nil) + deleteHashes := []string{moderationAPIKeyHash("sk-old-a")} + addKeys := []string{"sk-new-c", "sk-old-b"} + + view, err := svc.UpdateConfig(context.Background(), UpdateContentModerationConfigInput{ + APIKeys: &addKeys, + DeleteAPIKeyHashes: &deleteHashes, + }) + + require.NoError(t, err) + require.Equal(t, 2, view.APIKeyCount) + require.Equal(t, []string{maskSecretTail("sk-old-b"), maskSecretTail("sk-new-c")}, view.APIKeyMasks) + + var saved ContentModerationConfig + require.NoError(t, json.Unmarshal([]byte(repo.values[SettingKeyContentModerationConfig]), &saved)) + require.Equal(t, []string{"sk-old-b", "sk-new-c"}, saved.apiKeys()) +} + +func TestContentModerationUpdateConfig_ReplacesAPIKeysWhenRequested(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.APIKeys = []string{"sk-old-a", "sk-old-b"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyContentModerationConfig: string(rawCfg), + }} + svc := NewContentModerationService(repo, nil, nil, nil, nil, nil, nil) + deleteHashes := []string{moderationAPIKeyHash("sk-old-a")} + replaceKeys := []string{"sk-new-only"} + + view, err := svc.UpdateConfig(context.Background(), UpdateContentModerationConfigInput{ + APIKeys: &replaceKeys, + APIKeysMode: contentModerationAPIKeysModeReplace, + DeleteAPIKeyHashes: &deleteHashes, + }) + + require.NoError(t, err) + require.Equal(t, 1, view.APIKeyCount) + require.Equal(t, []string{maskSecretTail("sk-new-only")}, view.APIKeyMasks) + + var saved ContentModerationConfig + require.NoError(t, json.Unmarshal([]byte(repo.values[SettingKeyContentModerationConfig]), &saved)) + require.Equal(t, []string{"sk-new-only"}, saved.apiKeys()) +} + +func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"old"}, + {"role":"assistant","content":"ok"}, + {"role":"user","content":[ + {"type":"text","text":"检查这张图"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"aGVsbG8="}} + ]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + require.Equal(t, "检查这张图", input.Text) + require.Equal(t, []string{"data:image/png;base64,aGVsbG8="}, input.Images) + + log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "") + require.Equal(t, "检查这张图", log.InputExcerpt) + require.NotContains(t, log.InputExcerpt, "aGVsbG8=") +} + +func TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders(t *testing.T) { + body := []byte(`{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "工具说明"}, + {"type": "text", "text": "Ainder>\n\n"}, + {"type": "text", "text": "hid", "cache_control": {"type": "ephemeral"}} + ] + } + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + + require.Equal(t, "hid", input.Text) + require.Empty(t, input.Images) +} + +func TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage(t *testing.T) { + body := []byte(`{ + "model":"gpt-5.5", + "messages":[ + {"role":"system","content":"system prompt"}, + {"role":"user","content":"old user"}, + {"role":"assistant","content":"ok"}, + {"role":"user","content":[{"type":"text","text":"latest user"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body) + + require.Equal(t, "latest user", input.Text) + require.Equal(t, []string{"https://example.com/a.png"}, input.Images) + require.NotContains(t, input.Text, "old user") + require.NotContains(t, input.Text, "system prompt") +} + +func TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages(t *testing.T) { + body := []byte(`{ + "prompt":"replace background", + "images":[ + {"image_url":"https://example.com/source.png"}, + {"image_url":"data:image/png;base64,aGVsbG8="} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, body) + + require.Equal(t, "replace background", input.Text) + require.Equal(t, []string{"https://example.com/source.png", "data:image/png;base64,aGVsbG8="}, input.Images) +} + +func TestContentModerationInput_NormalizeKeepsImagesAndModerationInputSamplesOneImage(t *testing.T) { + images := []string{ + "data:image/png;base64,Zmlyc3Q=", + "data:image/png;base64,c2Vjb25k", + } + input := ContentModerationInput{ + Text: "check image", + Images: append([]string(nil), images...), + } + input.Normalize() + + require.Equal(t, images, input.Images) + + parts, ok := input.ModerationInput().([]moderationAPIInputPart) + require.True(t, ok) + require.Len(t, parts, 2) + require.Equal(t, "text", parts[0].Type) + require.Equal(t, "image_url", parts[1].Type) + require.NotNil(t, parts[1].ImageURL) + require.Contains(t, images, parts[1].ImageURL.URL) +} + +func TestBuildModerationTestInputRejectsMultipleImages(t *testing.T) { + _, _, err := buildModerationTestInput("check image", []string{ + "data:image/png;base64,Zmlyc3Q=", + "data:image/png;base64,c2Vjb25k", + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "最多上传 1 张测试图片") +} + +func TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage(t *testing.T) { + body := []byte(`{ + "model":"gpt-5.5", + "instructions":"instructions.....", + "input":[ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"developer permissions sk-proj-1234567890abcdef"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]} + ], + "prompt_cache_key":"cache-key" + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body) + + require.Equal(t, "last user prompt", input.Text) + require.Empty(t, input.Images) + require.NotContains(t, input.Text, "developer permissions") + require.NotContains(t, input.Text, "first user prompt") +} + +func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *testing.T) { + var moderationRequest moderationAPIRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest)) + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.01}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.RecordNonHits = true + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + body := []byte(`{ + "model":"gpt-5.5", + "input":[ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]} + ] + }`) + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Endpoint: "/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIResponses, + Body: body, + }) + + require.NoError(t, err) + require.False(t, decision.Blocked) + require.Len(t, repo.logs, 1) + require.False(t, repo.logs[0].Flagged) + require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action) + require.Equal(t, "/responses", repo.logs[0].Endpoint) + require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt) + require.Equal(t, "last user prompt", moderationRequest.Input) +} + +func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *testing.T) { + var moderationRequest moderationAPIRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest)) + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.BlockStatus = http.StatusUnavailableForLegalReasons + cfg.BlockMessage = "内容审计测试阻断" + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + body := []byte(`{ + "model":"gpt-5.5", + "instructions":"instructions.....", + "input":[ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"environment context"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"latest blocked prompt"}]} + ] + }`) + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Endpoint: "/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIResponses, + Body: body, + }) + + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionBlock, decision.Action) + require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode) + require.Equal(t, "内容审计测试阻断", decision.Message) + require.Len(t, repo.logs, 1) + require.True(t, repo.logs[0].Flagged) + require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action) + require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode) + require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt) + require.Equal(t, "latest blocked prompt", moderationRequest.Input) +} + +func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) { + result := buildContentModerationTestAuditResult(&moderationAPIResult{ + Flagged: true, + CategoryScores: map[string]float64{ + "harassment": 0.65, + }, + }, nil) + + require.NotNil(t, result) + require.False(t, result.Flagged) + require.Equal(t, "harassment", result.HighestCategory) + require.Equal(t, 0.65, result.HighestScore) + require.Equal(t, 0.65, result.CompositeScore) + require.Equal(t, 0.98, result.Thresholds["harassment"]) +} + +func TestContentModerationCallModeration_400DoesNotFreezeAPIKey(t *testing.T) { + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"Number of images (5) exceeds maximum of 1","type":"invalid_request_error","param":"input","code":"too_many_images"}}`)) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.RetryCount = 5 + svc := NewContentModerationService(nil, nil, nil, nil, nil, nil, nil) + + _, err := svc.callModeration(context.Background(), cfg, "hello") + + require.Error(t, err) + require.Equal(t, 1, requestCount) + status := svc.apiKeyStatusForHash(0, moderationAPIKeyHash("sk-test"), maskSecretTail("sk-test"), true) + require.Equal(t, "error", status.Status) + require.Equal(t, http.StatusBadRequest, status.LastHTTPStatus) + require.Zero(t, status.FailureCount) + require.Nil(t, status.FrozenUntil) +} + +func TestContentModerationCallModeration_FreezesByHTTPStatus(t *testing.T) { + tests := []struct { + name string + statusCode int + minFreeze time.Duration + maxFreeze time.Duration + }{ + {name: "401 freezes ten minutes", statusCode: http.StatusUnauthorized, minFreeze: 9*time.Minute + 55*time.Second, maxFreeze: 10*time.Minute + time.Second}, + {name: "403 freezes ten minutes", statusCode: http.StatusForbidden, minFreeze: 9*time.Minute + 55*time.Second, maxFreeze: 10*time.Minute + time.Second}, + {name: "429 freezes one minute", statusCode: http.StatusTooManyRequests, minFreeze: 55 * time.Second, maxFreeze: time.Minute + time.Second}, + {name: "529 freezes one minute", statusCode: 529, minFreeze: 55 * time.Second, maxFreeze: time.Minute + time.Second}, + {name: "500 freezes ten seconds", statusCode: http.StatusInternalServerError, minFreeze: 5 * time.Second, maxFreeze: 11 * time.Second}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + _, _ = w.Write([]byte(`{"error":{"message":"upstream error"}}`)) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.RetryCount = 0 + svc := NewContentModerationService(nil, nil, nil, nil, nil, nil, nil) + + _, err := svc.callModeration(context.Background(), cfg, "hello") + + require.Error(t, err) + status := svc.apiKeyStatusForHash(0, moderationAPIKeyHash("sk-test"), maskSecretTail("sk-test"), true) + require.Equal(t, "frozen", status.Status) + require.Equal(t, tt.statusCode, status.LastHTTPStatus) + require.Equal(t, 1, status.FailureCount) + require.NotNil(t, status.FrozenUntil) + remaining := time.Until(*status.FrozenUntil) + require.GreaterOrEqual(t, remaining, tt.minFreeze) + require.LessOrEqual(t, remaining, tt.maxFreeze) + }) + } +} + +func TestContentModerationTestAPIKeys_400DoesNotFreezeAPIKey(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"invalid moderation request"}}`)) + })) + defer server.Close() + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{}}, + nil, + nil, + nil, + nil, + nil, + nil, + ) + result, err := svc.TestAPIKeys(context.Background(), TestContentModerationAPIKeysInput{ + APIKeys: []string{"sk-test"}, + BaseURL: server.URL, + Prompt: "hello", + }) + + require.NoError(t, err) + require.Len(t, result.Items, 1) + require.Equal(t, "error", result.Items[0].Status) + require.Equal(t, http.StatusBadRequest, result.Items[0].LastHTTPStatus) + require.Zero(t, result.Items[0].FailureCount) + require.Nil(t, result.Items[0].FrozenUntil) +} + +func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.PreHashCheckEnabled = true + cfg.APIKeys = []string{"sk-test"} + cfg.BlockStatus = http.StatusConflict + cfg.BlockMessage = "命中历史风险输入" + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{}} + content := ContentModerationInput{Text: "blocked prompt"} + content.Normalize() + hashCache.hashes[content.Hash()] = struct{}{} + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + &contentModerationTestRepo{}, + hashCache, + nil, + nil, + nil, + nil, + ) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionHashBlock, decision.Action) + require.Equal(t, http.StatusConflict, decision.StatusCode) + require.Equal(t, content.Hash(), decision.InputHash) + require.Contains(t, decision.Message, "命中历史风险输入") + require.Contains(t, decision.Message, content.Hash()) + require.Len(t, hashCache.checked, 1) +} + +func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) { + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.PreHashCheckEnabled = true + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.BlockStatus = http.StatusConflict + cfg.BlockMessage = "命中风险输入" + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + hashCache := &contentModerationTestHashCache{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + hashCache, + nil, + nil, + nil, + nil, + ) + + body := []byte(`{"messages":[{"role":"user","content":"repeat blocked prompt"}]}`) + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: body, + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionBlock, decision.Action) + require.Equal(t, 1, requestCount) + require.Len(t, hashCache.recorded, 1) + require.Len(t, repo.logs, 1) + + decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: body, + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionHashBlock, decision.Action) + require.Equal(t, hashCache.recorded[0], decision.InputHash) + require.Equal(t, 1, requestCount) + require.Len(t, repo.logs, 1) +} + +func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) { + existingHash := strings.Repeat("a", 64) + hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{ + existingHash: {}, + }} + svc := &ContentModerationService{hashCache: hashCache} + + result, err := svc.DeleteFlaggedInputHash(context.Background(), strings.ToUpper(existingHash)) + + require.NoError(t, err) + require.Equal(t, existingHash, result.InputHash) + require.True(t, result.Deleted) + require.NotContains(t, hashCache.hashes, existingHash) + require.Equal(t, []string{existingHash}, hashCache.deleted) + + result, err = svc.DeleteFlaggedInputHash(context.Background(), existingHash) + + require.NoError(t, err) + require.Equal(t, existingHash, result.InputHash) + require.False(t, result.Deleted) +} + +func TestContentModerationClearFlaggedInputHashesAndStatusCount(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{ + strings.Repeat("a", 64): {}, + strings.Repeat("b", 64): {}, + }} + svc := &ContentModerationService{ + settingRepo: &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + hashCache: hashCache, + keyHealth: make(map[string]*contentModerationKeyHealth), + } + + status, err := svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(2), status.FlaggedHashCount) + + result, err := svc.ClearFlaggedInputHashes(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(2), result.Deleted) + + status, err = svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(0), status.FlaggedHashCount) +} + +func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModeObserve + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + hashCache := &contentModerationTestHashCache{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + hashCache, + nil, + nil, + nil, + nil, + ) + + decision := svc.checkSync(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"bad prompt"}]}`), + }, cfg, ContentModerationInput{Text: "bad prompt"}, strings.Repeat("b", 64), contentModerationIntPtr(25), false) + + require.False(t, decision.Blocked) + require.Len(t, hashCache.recorded, 1) + require.Len(t, repo.logs, 1) +} + +func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) { + userID := int64(1001) + cfg := defaultContentModerationConfig() + cfg.BanThreshold = 10 + body := buildContentModerationAccountDisabledEmailBody("Sub2API ", &ContentModerationLog{ + UserID: &userID, + UserEmail: "user@example.com", + GroupName: "vip_2", + HighestCategory: "sexual", + HighestScore: 0.926, + ViolationCount: 10, + }, cfg) + + require.Contains(t, body, "账户已被自动禁用") + require.Contains(t, body, "封禁详情") + require.Contains(t, body, "账户当前处于封禁状态,所有 API 请求将被拒绝") + require.Contains(t, body, "10 次(阈值 10)") + require.Contains(t, body, "sexual / 0.926") + require.Contains(t, body, "Sub2API <Admin>") +} + +func TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache(t *testing.T) { + userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusDisabled}} + invalidator := &contentModerationTestAuthCacheInvalidator{} + repo := &contentModerationTestRepo{} + svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil) + + result, err := svc.UnbanUser(context.Background(), 1001) + + require.NoError(t, err) + require.Equal(t, int64(1001), result.UserID) + require.Equal(t, StatusActive, result.Status) + require.Len(t, userRepo.updated, 1) + require.Equal(t, StatusActive, userRepo.updated[0].Status) + require.Equal(t, []int64{1001}, invalidator.userIDs) +} + +func TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache(t *testing.T) { + userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusActive}} + invalidator := &contentModerationTestAuthCacheInvalidator{} + repo := &contentModerationTestRepo{} + svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil) + + result, err := svc.UnbanUser(context.Background(), 1001) + + require.NoError(t, err) + require.Equal(t, StatusActive, result.Status) + require.Empty(t, userRepo.updated) + require.Equal(t, []int64{1001}, invalidator.userIDs) +} + +func contentModerationIntPtr(v int) *int { + return &v +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 6d0c5fda..c37fa29f 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -108,6 +108,12 @@ const ( SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结) SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久) SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限) + SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路 + SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON) + SettingKeyLoginAgreementEnabled = "login_agreement_enabled" // 登录前是否要求同意条款 + SettingKeyLoginAgreementMode = "login_agreement_mode" // 条款确认展示模式:modal / checkbox + SettingKeyLoginAgreementUpdatedAt = "login_agreement_updated_at" // 条款更新日期(展示用) + SettingKeyLoginAgreementDocuments = "login_agreement_documents" // 条款文档列表(JSON,Markdown 内容) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 @@ -174,6 +180,18 @@ const ( SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path" SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path" + // GitHub / Google 邮箱快捷登录设置 + SettingKeyGitHubOAuthEnabled = "github_oauth_enabled" + SettingKeyGitHubOAuthClientID = "github_oauth_client_id" + SettingKeyGitHubOAuthClientSecret = "github_oauth_client_secret" + SettingKeyGitHubOAuthRedirectURL = "github_oauth_redirect_url" + SettingKeyGitHubOAuthFrontendRedirectURL = "github_oauth_frontend_redirect_url" + SettingKeyGoogleOAuthEnabled = "google_oauth_enabled" + SettingKeyGoogleOAuthClientID = "google_oauth_client_id" + SettingKeyGoogleOAuthClientSecret = "google_oauth_client_secret" + SettingKeyGoogleOAuthRedirectURL = "google_oauth_redirect_url" + SettingKeyGoogleOAuthFrontendRedirectURL = "google_oauth_frontend_redirect_url" + // OEM设置 SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) @@ -217,6 +235,16 @@ const ( SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance" + SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency" + SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions" + SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup" + SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind" + SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance" + SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency" + SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions" + SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup" + SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind" SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" // 管理员 API Key diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index ecaffe21..6919c148 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -124,6 +124,24 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) { require.Contains(t, got, "fast-mode-2026-02-01") } +func TestFullClaudeCodeMimicryBetas_DoesNotDefaultRedactThinking(t *testing.T) { + required := claude.FullClaudeCodeMimicryBetas() + + require.NotContains(t, required, claude.BetaRedactThinking) + require.Contains(t, required, claude.BetaClaudeCode) + require.Contains(t, required, claude.BetaOAuth) + require.Contains(t, required, claude.BetaInterleavedThinking) +} + +func TestMergeAnthropicBetaDropping_PreservesIncomingRedactThinking(t *testing.T) { + required := claude.FullClaudeCodeMimicryBetas() + incoming := claude.BetaRedactThinking + + got := mergeAnthropicBetaDropping(required, incoming, droppedBetaSet()) + + require.Contains(t, got, claude.BetaRedactThinking) +} + func TestDroppedBetaSet(t *testing.T) { // Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy) base := droppedBetaSet() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a1f3f353..4b7f638f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5445,6 +5445,12 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( flusher.Flush() } if !sawTerminalEvent { + if clientDisconnected && streamInterval > 0 { + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) >= streamInterval { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") + } + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index a5fe707d..e4430536 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -440,6 +440,21 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) } +func (s *OpenAIGatewayService) isCodexImageGenerationBridgeEnabled(ctx context.Context, account *Account, apiKey *APIKey) bool { + if override := account.CodexImageGenerationBridgeOverride(); override != nil { + return *override + } + if s != nil && s.channelService != nil && apiKey != nil && apiKey.GroupID != nil { + ch, err := s.channelService.GetChannelForGroup(ctx, *apiKey.GroupID) + if err != nil { + slog.Warn("failed to resolve codex image generation bridge channel override", "group_id", *apiKey.GroupID, "error", err) + } else if override := ch.CodexImageGenerationBridgeOverride(PlatformOpenAI); override != nil { + return *override + } + } + return s != nil && s.cfg != nil && s.cfg.Gateway.CodexImageGenerationBridgeEnabled +} + func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { if groupID == nil || s.channelService == nil || requestedModel == "" { return false @@ -2059,6 +2074,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if apiKey != nil { imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group) } + codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey) if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") c.JSON(http.StatusForbidden, gin.H{ @@ -2128,7 +2144,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("instructions", "You are a helpful coding assistant.") } - if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) { + if codexImageGenerationBridgeEnabled && ensureOpenAIResponsesImageGenerationTool(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") @@ -2139,7 +2155,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") } - if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) { + if codexImageGenerationBridgeEnabled && applyCodexImageGenerationBridgeInstructions(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") diff --git a/backend/internal/service/openai_image_generation_controls_test.go b/backend/internal/service/openai_image_generation_controls_test.go index 76dc8053..9ff8b510 100644 --- a/backend/internal/service/openai_image_generation_controls_test.go +++ b/backend/internal/service/openai_image_generation_controls_test.go @@ -83,12 +83,14 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability( gin.SetMode(gin.TestMode) tests := []struct { - name string - allowImages bool - wantInjected bool + name string + allowImages bool + bridgeEnabled bool + wantInjected bool }{ - {name: "disabled group skips injection", allowImages: false, wantInjected: false}, - {name: "enabled group injects image tool", allowImages: true, wantInjected: true}, + {name: "disabled group skips injection", allowImages: false, bridgeEnabled: true, wantInjected: false}, + {name: "enabled group skips injection by default", allowImages: true, bridgeEnabled: false, wantInjected: false}, + {name: "enabled group injects image tool when bridge enabled", allowImages: true, bridgeEnabled: true, wantInjected: true}, } for _, tt := range tests { @@ -101,6 +103,7 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability( }, } svc := newOpenAIImageGenerationControlTestService(upstream) + svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.bridgeEnabled c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0") account := newOpenAIImageGenerationControlTestAccount() @@ -117,6 +120,154 @@ func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability( } } +func TestOpenAIGatewayServiceForward_ExplicitImageToolWorksWithBridgeDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_explicit_image","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}`)), + }, + } + svc := newOpenAIImageGenerationControlTestService(upstream) + c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0") + account := newOpenAIImageGenerationControlTestAccount() + body := []byte(`{"model":"gpt-5.4","input":"draw","stream":false,"tools":[{"type":"image_generation","format":"jpeg"}]}`) + + result, err := svc.Forward(context.Background(), c, account, body) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists()) + require.Equal(t, "jpeg", gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").output_format`).String()) + require.False(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation").format`).Exists()) + instructions := gjson.GetBytes(upstream.lastBody, "instructions").String() + require.NotContains(t, instructions, "image_generation") +} + +func TestOpenAIGatewayServiceForward_ChannelBridgeOverrideEnablesCodexInjection(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_channel_bridge","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + svc := newOpenAIImageGenerationControlTestService(upstream) + groupID := int64(4242) + svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, &Channel{ + ID: 9001, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true}, + }, + }) + c, _ := newOpenAIImageGenerationControlTestContext(true, "codex_cli_rs/0.98.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`)) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.True(t, gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists()) + instructions := gjson.GetBytes(upstream.lastBody, "instructions").String() + require.Contains(t, instructions, "image_generation") +} + +func TestOpenAIGatewayService_CodexImageGenerationBridgeOverridePrecedence(t *testing.T) { + groupID := int64(4242) + + tests := []struct { + name string + global bool + channel *Channel + account *Account + want bool + }{ + { + name: "global default enables bridge", + global: true, + account: &Account{ + Platform: PlatformOpenAI, + }, + want: true, + }, + { + name: "channel true overrides disabled global", + global: false, + channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{ + featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true}, + }}, + account: &Account{Platform: PlatformOpenAI}, + want: true, + }, + { + name: "channel false overrides enabled global", + global: true, + channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{ + featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false}, + }}, + account: &Account{Platform: PlatformOpenAI}, + want: false, + }, + { + name: "account false overrides channel and global true", + global: true, + channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{ + featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: true}, + }}, + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{featureKeyCodexImageGenerationBridge: false}, + }, + want: false, + }, + { + name: "nested account true overrides channel false", + global: false, + channel: &Channel{ID: 1, Status: StatusActive, FeaturesConfig: map[string]any{ + featureKeyCodexImageGenerationBridge: map[string]any{PlatformOpenAI: false}, + }}, + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + PlatformOpenAI: map[string]any{"codex_image_generation_bridge_enabled": true}, + }, + }, + want: true, + }, + { + name: "non openai account extra is ignored", + global: false, + account: &Account{ + Platform: PlatformAnthropic, + Extra: map[string]any{featureKeyCodexImageGenerationBridge: true}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{}) + svc.cfg.Gateway.CodexImageGenerationBridgeEnabled = tt.global + if tt.channel != nil { + svc.channelService = newOpenAIImageGenerationControlChannelService(groupID, tt.channel) + } + apiKey := &APIKey{GroupID: &groupID} + + got := svc.isCodexImageGenerationBridgeEnabled(context.Background(), tt.account, apiKey) + + require.Equal(t, tt.want, got) + }) + } +} + func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) { gin.SetMode(gin.TestMode) @@ -180,6 +331,18 @@ func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) } } +func newOpenAIImageGenerationControlChannelService(groupID int64, ch *Channel) *ChannelService { + svc := &ChannelService{} + cache := newEmptyChannelCache() + if ch != nil { + cache.channelByGroupID[groupID] = ch + cache.byID[ch.ID] = ch + } + cache.loadedAt = time.Now() + svc.cache.Store(cache) + return svc +} + func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) { recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 04be5164..afa94156 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -90,6 +90,69 @@ type OpenAIImagesRequest struct { bodyHash string } +func (r *OpenAIImagesRequest) ModerationBody() []byte { + if r == nil { + return nil + } + payload := map[string]any{} + if prompt := strings.TrimSpace(r.Prompt); prompt != "" { + payload["prompt"] = prompt + } + images := r.moderationImages() + if len(images) > 0 { + payload["images"] = images + } + if len(payload) == 0 { + return nil + } + body, err := json.Marshal(payload) + if err != nil { + return nil + } + return body +} + +func (r *OpenAIImagesRequest) moderationImages() []map[string]string { + if r == nil { + return nil + } + images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1) + for _, imageURL := range r.InputImageURLs { + imageURL = strings.TrimSpace(imageURL) + if imageURL != "" { + images = append(images, map[string]string{"image_url": imageURL}) + } + } + for _, upload := range r.Uploads { + if dataURL := upload.ModerationDataURL(); dataURL != "" { + images = append(images, map[string]string{"image_url": dataURL}) + } + } + if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" { + images = append(images, map[string]string{"image_url": maskURL}) + } + if r.MaskUpload != nil { + if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" { + images = append(images, map[string]string{"image_url": dataURL}) + } + } + return images +} + +func (u OpenAIImagesUpload) ModerationDataURL() string { + if len(u.Data) == 0 { + return "" + } + contentType := strings.TrimSpace(u.ContentType) + if contentType == "" { + contentType = http.DetectContentType(u.Data) + } + if !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return "" + } + return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data)) +} + func (r *OpenAIImagesRequest) IsEdits() bool { return r != nil && r.Endpoint == openAIImagesEditsEndpoint } diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index fa4a4415..45fb24e9 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } +func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Prompt: "replace background", + InputImageURLs: []string{"https://example.com/source.png"}, + MaskImageURL: "https://example.com/mask.png", + } + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody()) + + require.Equal(t, "replace background", input.Text) + require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images) +} + +func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Prompt: "replace background", + Uploads: []OpenAIImagesUpload{{ + FieldName: "image", + FileName: "source.png", + ContentType: "image/png", + Data: []byte("fake-image-bytes"), + }}, + MaskUpload: &OpenAIImagesUpload{ + FieldName: "mask", + FileName: "mask.png", + ContentType: "image/png", + Data: []byte("fake-mask-bytes"), + }, + } + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody()) + + require.Equal(t, "replace background", input.Text) + require.Equal(t, []string{ + "data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==", + "data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz", + }, input.Images) + + log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "") + require.Equal(t, "replace background", log.InputExcerpt) + require.NotContains(t, log.InputExcerpt, "ZmFrZS") +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_oauth_service_refresh_test.go b/backend/internal/service/openai_oauth_service_refresh_test.go index a31eb8cb..84b68ea6 100644 --- a/backend/internal/service/openai_oauth_service_refresh_test.go +++ b/backend/internal/service/openai_oauth_service_refresh_test.go @@ -52,3 +52,47 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess require.Equal(t, "client-id-1", info.ClientID) require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh") } + +func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) { + refresher := NewOpenAITokenRefresher(nil, nil) + expiresAt := time.Now().Add(time.Minute).UTC().Format(time.RFC3339) + + withoutRT := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": expiresAt, + }, + } + require.False(t, refresher.NeedsRefresh(withoutRT, 5*time.Minute)) + + withRT := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, + }, + } + require.True(t, refresher.NeedsRefresh(withRT, 5*time.Minute)) +} + +func TestOpenAITokenProvider_NoRefreshTokenExpiredAccessTokenReturnsError(t *testing.T) { + provider := NewOpenAITokenProvider(nil, nil, nil) + expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "expired-access-token", + "expires_at": expiresAt, + }, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Empty(t, token) + require.Contains(t, err.Error(), "refresh_token is missing") +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index e438588e..a680d451 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -152,6 +152,12 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew + if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" { + if expiresAt != nil && !time.Now().Before(*expiresAt) { + return "", errors.New("openai access_token expired and refresh_token is missing") + } + needsRefresh = false + } refreshFailed := false if needsRefresh && p.refreshAPI != nil && p.executor != nil { diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index e81fb465..4b69db8a 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -424,8 +424,9 @@ func TestOpenAITokenProvider_CacheGetError(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "fallback-token", - "expires_at": expiresAt, + "access_token": "fallback-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, }, } @@ -650,8 +651,9 @@ func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "fallback-token", - "expires_at": expiresAt, + "access_token": "fallback-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, }, } @@ -819,8 +821,9 @@ func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "fallback-token", - "expires_at": expiresAt, + "access_token": "fallback-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, }, } @@ -848,8 +851,9 @@ func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "fallback-token", - "expires_at": expiresAt, + "access_token": "fallback-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, }, } @@ -875,8 +879,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "fallback-token", - "expires_at": expiresAt, + "access_token": "fallback-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, }, } cacheKey := OpenAITokenCacheKey(account) @@ -911,8 +916,9 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "fallback-token", - "expires_at": expiresAt, + "access_token": "fallback-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, }, } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 784cdbe5..372f420f 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct { // 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。 InitialRequestModel string BeforeTurn func(turn int) error + BeforeRequest func(turn int, payload []byte, originalModel string) error AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) } @@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return true } for { + if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil { + if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil { + return err + } + } if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { if err := hooks.BeforeTurn(turn); err != nil { return err diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 8bc17d42..e2760725 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( if msgType != coderws.MessageText { return payload, nil, nil } + if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil { + turnNo := int(completedTurns.Load()) + 1 + if turnNo < 2 { + turnNo = 2 + } + requestModel := usageMeta.requestModelForFrame(payload) + if requestModel == "" { + requestModel = capturedSessionModel + } + if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil { + return payload, nil, err + } + } // 在评估策略前先刷新 capturedSessionModel:客户端可能通过 // session.update 修改 session-level model(Realtime / // Responses WS 协议允许),如果不刷新就会出现 diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 4ae6d134..f96684a4 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e case redeemActionRedeem: // Code exists but unused — skip creation, proceed to redeem } - if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { + if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) } if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go index 8dfd2e7e..d8595715 100644 --- a/backend/internal/service/payment_order_lifecycle_test.go +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ @@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) { nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ @@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ @@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 9ced6201..dcf293c5 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -11,6 +11,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -28,6 +29,15 @@ const ( redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁 ) +type ctxKeySkipRedeemAffiliate struct{} + +// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level +// affiliate rebate. Used by payment fulfillment which handles rebate separately +// via applyAffiliateRebateForOrder (with audit-log deduplication). +func ContextSkipRedeemAffiliate(ctx context.Context) context.Context { + return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true) +} + // RedeemCache defines cache operations for redeem service type RedeemCache interface { GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) @@ -80,6 +90,7 @@ type RedeemService struct { billingCacheService *BillingCacheService entClient *dbent.Client authCacheInvalidator APIKeyAuthCacheInvalidator + affiliateService *AffiliateService } // NewRedeemService 创建兑换码服务实例 @@ -91,6 +102,7 @@ func NewRedeemService( billingCacheService *BillingCacheService, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator, + affiliateService *AffiliateService, ) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, @@ -100,6 +112,7 @@ func NewRedeemService( billingCacheService: billingCacheService, entClient: entClient, authCacheInvalidator: authCacheInvalidator, + affiliateService: affiliateService, } } @@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 事务提交成功后失效缓存 s.invalidateRedeemCaches(ctx, userID, redeemCode) + // 余额类正数兑换码触发邀请返利(best-effort,失败不影响兑换结果) + if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 { + s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value) + } + // 重新获取更新后的兑换码 redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID) if err != nil { @@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64 } } +func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) { + if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil { + return + } + if s.affiliateService == nil { + return + } + if !s.affiliateService.IsEnabled(ctx) { + return + } + rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount) + if err != nil { + logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err) + return + } + if rebate > 0 { + logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID) + } +} + // GetByID 根据ID获取兑换码 func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a5d65ad7..283a239b 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -3,6 +3,7 @@ package service import ( "context" "crypto/rand" + "crypto/sha256" "encoding/hex" "encoding/json" "errors" @@ -129,6 +130,8 @@ type AuthSourceDefaultSettings struct { LinuxDo ProviderDefaultGrantSettings OIDC ProviderDefaultGrantSettings WeChat ProviderDefaultGrantSettings + GitHub ProviderDefaultGrantSettings + Google ProviderDefaultGrantSettings ForceEmailOnThirdPartySignup bool } @@ -169,6 +172,20 @@ var ( grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, } + gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultGitHubBalance, + concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency, + subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind, + } + googleAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultGoogleBalance, + concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency, + subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind, + } ) const ( @@ -177,8 +194,151 @@ const ( defaultWeChatConnectMode = "open" defaultWeChatConnectScopes = "snsapi_login" defaultWeChatConnectFrontend = "/auth/wechat/callback" + defaultGitHubOAuthAuthorize = "https://github.com/login/oauth/authorize" + defaultGitHubOAuthToken = "https://github.com/login/oauth/access_token" + defaultGitHubOAuthUserInfo = "https://api.github.com/user" + defaultGitHubOAuthEmails = "https://api.github.com/user/emails" + defaultGitHubOAuthScopes = "read:user user:email" + defaultGitHubOAuthFrontend = "/auth/oauth/callback" + defaultGoogleOAuthAuthorize = "https://accounts.google.com/o/oauth2/v2/auth" + defaultGoogleOAuthToken = "https://oauth2.googleapis.com/token" + defaultGoogleOAuthUserInfo = "https://openidconnect.googleapis.com/v1/userinfo" + defaultGoogleOAuthScopes = "openid email profile" + defaultGoogleOAuthFrontend = "/auth/oauth/callback" + defaultLoginAgreementMode = "modal" + defaultLoginAgreementDate = "2026-03-31" ) +func normalizeLoginAgreementMode(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "checkbox": + return "checkbox" + default: + return defaultLoginAgreementMode + } +} + +func defaultLoginAgreementDocuments() []LoginAgreementDocument { + return []LoginAgreementDocument{ + { + ID: "terms", + Title: "服务条款", + ContentMD: "", + }, + { + ID: "usage-policy", + Title: "使用政策", + ContentMD: "", + }, + { + ID: "supported-regions", + Title: "支持的国家和地区", + ContentMD: "", + }, + { + ID: "service-specific-terms", + Title: "服务特定条款", + ContentMD: "", + }, + } +} + +func normalizeLoginAgreementDocumentID(raw string) string { + raw = strings.ToLower(strings.TrimSpace(raw)) + var b strings.Builder + lastSeparator := false + for _, r := range raw { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { + _, _ = b.WriteRune(r) + lastSeparator = false + continue + } + if r == '-' || r == '_' || r == ' ' || r == '.' || r == '/' { + if !lastSeparator && b.Len() > 0 { + if r == '_' { + _, _ = b.WriteRune('_') + } else { + _, _ = b.WriteRune('-') + } + lastSeparator = true + } + } + } + return strings.Trim(b.String(), "-_") +} + +func normalizeLoginAgreementDocuments(docs []LoginAgreementDocument) []LoginAgreementDocument { + normalized := make([]LoginAgreementDocument, 0, len(docs)) + seen := make(map[string]int, len(docs)) + for i, doc := range docs { + title := strings.TrimSpace(doc.Title) + content := strings.TrimSpace(doc.ContentMD) + if title == "" && content == "" { + continue + } + id := normalizeLoginAgreementDocumentID(doc.ID) + if id == "" { + sum := sha256.Sum256([]byte(fmt.Sprintf("%d:%s:%s", i, title, content))) + id = hex.EncodeToString(sum[:])[:12] + } + baseID := id + for suffix := 2; seen[id] > 0; suffix++ { + id = fmt.Sprintf("%s-%d", baseID, suffix) + } + seen[id]++ + normalized = append(normalized, LoginAgreementDocument{ + ID: id, + Title: title, + ContentMD: content, + }) + } + return normalized +} + +func parseLoginAgreementDocuments(raw string) []LoginAgreementDocument { + raw = strings.TrimSpace(raw) + if raw == "" { + return defaultLoginAgreementDocuments() + } + var docs []LoginAgreementDocument + if err := json.Unmarshal([]byte(raw), &docs); err != nil { + return defaultLoginAgreementDocuments() + } + docs = normalizeLoginAgreementDocuments(docs) + if len(docs) == 0 { + return defaultLoginAgreementDocuments() + } + return docs +} + +func marshalLoginAgreementDocuments(docs []LoginAgreementDocument) (string, error) { + normalized := normalizeLoginAgreementDocuments(docs) + if len(normalized) == 0 { + normalized = defaultLoginAgreementDocuments() + } + b, err := json.Marshal(normalized) + if err != nil { + return "", fmt.Errorf("marshal login agreement documents: %w", err) + } + return string(b), nil +} + +func buildLoginAgreementRevision(updatedAt string, docs []LoginAgreementDocument) string { + normalized := normalizeLoginAgreementDocuments(docs) + payload, err := json.Marshal(struct { + UpdatedAt string `json:"updated_at"` + Documents []LoginAgreementDocument `json:"documents"` + }{ + UpdatedAt: strings.TrimSpace(updatedAt), + Documents: normalized, + }) + if err != nil { + payload = []byte(strings.TrimSpace(updatedAt)) + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:])[:16] +} + func normalizeWeChatConnectModeSetting(raw string) string { switch strings.ToLower(strings.TrimSpace(raw)) { case "mp": @@ -411,6 +571,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPasswordResetEnabled, SettingKeyInvitationCodeEnabled, SettingKeyTotpEnabled, + SettingKeyLoginAgreementEnabled, + SettingKeyLoginAgreementMode, + SettingKeyLoginAgreementUpdatedAt, + SettingKeyLoginAgreementDocuments, SettingKeyTurnstileEnabled, SettingKeyTurnstileSiteKey, SettingKeySiteName, @@ -448,6 +612,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingPaymentEnabled, SettingKeyOIDCConnectEnabled, SettingKeyOIDCConnectProviderName, + SettingKeyGitHubOAuthEnabled, + SettingKeyGitHubOAuthClientID, + SettingKeyGitHubOAuthClientSecret, + SettingKeyGoogleOAuthEnabled, + SettingKeyGoogleOAuthClientID, + SettingKeyGoogleOAuthClientSecret, SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold, SettingKeyBalanceLowNotifyRechargeURL, @@ -456,6 +626,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyChannelMonitorDefaultIntervalSeconds, SettingKeyAvailableChannelsEnabled, SettingKeyAffiliateEnabled, + SettingKeyRiskControlEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -482,6 +653,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } + gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github") + googleEnabled := s.emailOAuthPublicEnabled(settings, "google") weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings) // Password reset requires email verification to be enabled @@ -494,6 +667,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings settings[SettingKeyTableDefaultPageSize], settings[SettingKeyTablePageSizeOptions], ) + loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments]) + loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt]) + if loginAgreementUpdatedAt == "" { + loginAgreementUpdatedAt = defaultLoginAgreementDate + } var balanceLowNotifyThreshold float64 if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 { @@ -509,6 +687,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PasswordResetEnabled: passwordResetEnabled, InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true" && len(loginAgreementDocuments) > 0, + LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]), + LoginAgreementUpdatedAt: loginAgreementUpdatedAt, + LoginAgreementRevision: buildLoginAgreementRevision(loginAgreementUpdatedAt, loginAgreementDocuments), + LoginAgreementDocuments: loginAgreementDocuments, TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), @@ -534,6 +717,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, OIDCOAuthProviderName: oidcProviderName, + GitHubOAuthEnabled: gitHubEnabled, + GoogleOAuthEnabled: googleEnabled, BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true", AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true", BalanceLowNotifyThreshold: balanceLowNotifyThreshold, @@ -545,6 +730,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true", AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true", + + RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true", }, nil } @@ -647,43 +834,50 @@ func (s *SettingService) SetVersion(version string) { // A unit test diffs this struct's JSON keys against dto.PublicSettings to catch // drift automatically (see setting_service_injection_test.go). type PublicSettingsInjectionPayload struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - TableDefaultPageSize int `json:"table_default_page_size"` - TablePageSizeOptions []int `json:"table_page_size_options"` - CustomMenuItems json.RawMessage `json:"custom_menu_items"` - CustomEndpoints json.RawMessage `json:"custom_endpoints"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` - WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` - WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` - WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` - OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` - OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` - BackendModeEnabled bool `json:"backend_mode_enabled"` - PaymentEnabled bool `json:"payment_enabled"` - Version string `json:"version"` - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + LoginAgreementEnabled bool `json:"login_agreement_enabled"` + LoginAgreementMode string `json:"login_agreement_mode"` + LoginAgreementUpdatedAt string `json:"login_agreement_updated_at"` + LoginAgreementRevision string `json:"login_agreement_revision"` + LoginAgreementDocuments []LoginAgreementDocument `json:"login_agreement_documents"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + CustomEndpoints json.RawMessage `json:"custom_endpoints"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` + OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` + OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` + GitHubOAuthEnabled bool `json:"github_oauth_enabled"` + GoogleOAuthEnabled bool `json:"google_oauth_enabled"` + BackendModeEnabled bool `json:"backend_mode_enabled"` + PaymentEnabled bool `json:"payment_enabled"` + Version string `json:"version"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` // Feature flags — MUST match the opt-in/opt-out registry in // frontend/src/utils/featureFlags.ts. Missing a field here is the bug @@ -692,6 +886,7 @@ type PublicSettingsInjectionPayload struct { ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` AvailableChannelsEnabled bool `json:"available_channels_enabled"` AffiliateEnabled bool `json:"affiliate_enabled"` + RiskControlEnabled bool `json:"risk_control_enabled"` } // GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection. @@ -710,6 +905,11 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PasswordResetEnabled: settings.PasswordResetEnabled, InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, + LoginAgreementEnabled: settings.LoginAgreementEnabled, + LoginAgreementMode: settings.LoginAgreementMode, + LoginAgreementUpdatedAt: settings.LoginAgreementUpdatedAt, + LoginAgreementRevision: settings.LoginAgreementRevision, + LoginAgreementDocuments: settings.LoginAgreementDocuments, TurnstileEnabled: settings.TurnstileEnabled, TurnstileSiteKey: settings.TurnstileSiteKey, SiteName: settings.SiteName, @@ -733,6 +933,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, + GitHubOAuthEnabled: settings.GitHubOAuthEnabled, + GoogleOAuthEnabled: settings.GoogleOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, Version: s.version, @@ -745,6 +947,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: settings.AvailableChannelsEnabled, AffiliateEnabled: settings.AffiliateEnabled, + RiskControlEnabled: settings.RiskControlEnabled, }, nil } @@ -806,6 +1009,98 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string return openReady || mpReady, openReady, mpReady, mobileReady } +func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "github": + cfg := config.EmailOAuthProviderConfig{ + AuthorizeURL: defaultGitHubOAuthAuthorize, + TokenURL: defaultGitHubOAuthToken, + UserInfoURL: defaultGitHubOAuthUserInfo, + EmailsURL: defaultGitHubOAuthEmails, + Scopes: defaultGitHubOAuthScopes, + FrontendRedirectURL: defaultGitHubOAuthFrontend, + } + if s != nil && s.cfg != nil { + cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GitHubOAuth) + } + return cfg + case "google": + cfg := config.EmailOAuthProviderConfig{ + AuthorizeURL: defaultGoogleOAuthAuthorize, + TokenURL: defaultGoogleOAuthToken, + UserInfoURL: defaultGoogleOAuthUserInfo, + Scopes: defaultGoogleOAuthScopes, + FrontendRedirectURL: defaultGoogleOAuthFrontend, + } + if s != nil && s.cfg != nil { + cfg = mergeEmailOAuthBaseConfig(cfg, s.cfg.GoogleOAuth) + } + return cfg + default: + return config.EmailOAuthProviderConfig{} + } +} + +func mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) config.EmailOAuthProviderConfig { + base.Enabled = override.Enabled + if strings.TrimSpace(override.ClientID) != "" { + base.ClientID = strings.TrimSpace(override.ClientID) + } + if strings.TrimSpace(override.ClientSecret) != "" { + base.ClientSecret = strings.TrimSpace(override.ClientSecret) + } + if strings.TrimSpace(override.AuthorizeURL) != "" { + base.AuthorizeURL = strings.TrimSpace(override.AuthorizeURL) + } + if strings.TrimSpace(override.TokenURL) != "" { + base.TokenURL = strings.TrimSpace(override.TokenURL) + } + if strings.TrimSpace(override.UserInfoURL) != "" { + base.UserInfoURL = strings.TrimSpace(override.UserInfoURL) + } + if strings.TrimSpace(override.EmailsURL) != "" { + base.EmailsURL = strings.TrimSpace(override.EmailsURL) + } + if strings.TrimSpace(override.Scopes) != "" { + base.Scopes = strings.TrimSpace(override.Scopes) + } + if strings.TrimSpace(override.RedirectURL) != "" { + base.RedirectURL = strings.TrimSpace(override.RedirectURL) + } + if strings.TrimSpace(override.FrontendRedirectURL) != "" { + base.FrontendRedirectURL = strings.TrimSpace(override.FrontendRedirectURL) + } + return base +} + +func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool { + cfg := s.effectiveEmailOAuthConfig(settings, provider) + return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != "" +} + +func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig { + cfg := s.emailOAuthBaseConfig(provider) + switch strings.ToLower(strings.TrimSpace(provider)) { + case "github": + if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok { + cfg.Enabled = raw == "true" + } + cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID) + cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret) + cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL) + cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend) + case "google": + if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok { + cfg.Enabled = raw == "true" + } + cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID) + cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret) + cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL) + cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend) + } + return cfg +} + // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". func filterUserVisibleMenuItems(raw string) json.RawMessage { @@ -1052,6 +1347,16 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting if settings.WeChatConnectFrontendRedirectURL == "" { settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend } + settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL) + settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL) + if settings.GitHubOAuthFrontendRedirectURL == "" { + settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend + } + settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL) + settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL) + if settings.GoogleOAuthFrontendRedirectURL == "" { + settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend + } updates := make(map[string]string) @@ -1068,6 +1373,19 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyFrontendURL] = settings.FrontendURL updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) + settings.LoginAgreementMode = normalizeLoginAgreementMode(settings.LoginAgreementMode) + settings.LoginAgreementUpdatedAt = strings.TrimSpace(settings.LoginAgreementUpdatedAt) + if settings.LoginAgreementUpdatedAt == "" { + settings.LoginAgreementUpdatedAt = defaultLoginAgreementDate + } + loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(settings.LoginAgreementDocuments) + if err != nil { + return nil, err + } + updates[SettingKeyLoginAgreementEnabled] = strconv.FormatBool(settings.LoginAgreementEnabled) + updates[SettingKeyLoginAgreementMode] = settings.LoginAgreementMode + updates[SettingKeyLoginAgreementUpdatedAt] = settings.LoginAgreementUpdatedAt + updates[SettingKeyLoginAgreementDocuments] = loginAgreementDocumentsJSON // 邮件服务设置(只有非空才更新密码) updates[SettingKeySMTPHost] = settings.SMTPHost @@ -1121,6 +1439,22 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret } + // GitHub / Google 邮箱快捷登录 + updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled) + updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID) + updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL + updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL + if settings.GitHubOAuthClientSecret != "" { + updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret) + } + updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled) + updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID) + updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL + updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL + if settings.GoogleOAuthClientSecret != "" { + updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret) + } + // WeChat Connect OAuth 登录 updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID @@ -1232,6 +1566,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting // Affiliate (邀请返利) feature switch updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled) + // 风控中心功能开关 + updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled) + // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion @@ -1273,17 +1610,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett settings.LinuxDo.Subscriptions, settings.OIDC.Subscriptions, settings.WeChat.Subscriptions, + settings.GitHub.Subscriptions, + settings.Google.Subscriptions, } { if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { return nil, err } } - updates := make(map[string]string, 21) + updates := make(map[string]string, 31) writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub) + writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google) updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) return updates, nil } @@ -1362,6 +1703,61 @@ func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, return nil } +func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider != "github" && provider != "google" { + return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_PROVIDER_NOT_FOUND", "oauth provider not found") + } + keys := []string{ + SettingKeyGitHubOAuthEnabled, + SettingKeyGitHubOAuthClientID, + SettingKeyGitHubOAuthClientSecret, + SettingKeyGitHubOAuthRedirectURL, + SettingKeyGitHubOAuthFrontendRedirectURL, + SettingKeyGoogleOAuthEnabled, + SettingKeyGoogleOAuthClientID, + SettingKeyGoogleOAuthClientSecret, + SettingKeyGoogleOAuthRedirectURL, + SettingKeyGoogleOAuthFrontendRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return config.EmailOAuthProviderConfig{}, fmt.Errorf("get email oauth settings: %w", err) + } + cfg := s.effectiveEmailOAuthConfig(settings, provider) + if !cfg.Enabled { + return config.EmailOAuthProviderConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + if strings.TrimSpace(cfg.ClientID) == "" { + return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured") + } + if strings.TrimSpace(cfg.ClientSecret) == "" { + return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") + } + for label, rawURL := range map[string]string{ + "authorize": cfg.AuthorizeURL, + "token": cfg.TokenURL, + "userinfo": cfg.UserInfoURL, + "redirect": cfg.RedirectURL, + } { + if strings.TrimSpace(rawURL) == "" { + return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url not configured") + } + if err := config.ValidateAbsoluteHTTPURL(rawURL); err != nil { + return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth "+label+" url invalid") + } + } + if strings.TrimSpace(cfg.EmailsURL) != "" { + if err := config.ValidateAbsoluteHTTPURL(cfg.EmailsURL); err != nil { + return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth emails url invalid") + } + } + if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { + return config.EmailOAuthProviderConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid") + } + return cfg, nil +} + // IsRegistrationEnabled 检查是否开放注册 func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) @@ -1534,6 +1930,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool { return value == "true" } +// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting. +func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string { + value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems) + if err != nil { + return "[]" + } + return value +} + // IsAffiliateEnabled 检查是否启用邀请返利功能(总开关) func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled) @@ -1711,6 +2116,16 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut SettingKeyAuthSourceDefaultWeChatSubscriptions, SettingKeyAuthSourceDefaultWeChatGrantOnSignup, SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + SettingKeyAuthSourceDefaultGitHubBalance, + SettingKeyAuthSourceDefaultGitHubConcurrency, + SettingKeyAuthSourceDefaultGitHubSubscriptions, + SettingKeyAuthSourceDefaultGitHubGrantOnSignup, + SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind, + SettingKeyAuthSourceDefaultGoogleBalance, + SettingKeyAuthSourceDefaultGoogleConcurrency, + SettingKeyAuthSourceDefaultGoogleSubscriptions, + SettingKeyAuthSourceDefaultGoogleGrantOnSignup, + SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind, SettingKeyForceEmailOnThirdPartySignup, } @@ -1724,6 +2139,8 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys), OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys), WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), + GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys), + Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys), ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", }, nil } @@ -1793,6 +2210,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken } } + loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(defaultLoginAgreementDocuments()) + if err != nil { + return err + } // 初始化默认设置 defaults := map[string]string{ @@ -1800,6 +2221,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyEmailVerifyEnabled: "false", SettingKeyRegistrationEmailSuffixWhitelist: "[]", SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeyLoginAgreementEnabled: "false", + SettingKeyLoginAgreementMode: defaultLoginAgreementMode, + SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate, + SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON, SettingKeySiteName: "Sub2API", SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", @@ -1824,6 +2249,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyWeChatConnectScopes: "snsapi_login", SettingKeyWeChatConnectRedirectURL: "", SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, + SettingKeyGitHubOAuthEnabled: "false", + SettingKeyGitHubOAuthClientID: "", + SettingKeyGitHubOAuthClientSecret: "", + SettingKeyGitHubOAuthRedirectURL: "", + SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend, + SettingKeyGoogleOAuthEnabled: "false", + SettingKeyGoogleOAuthClientID: "", + SettingKeyGoogleOAuthClientSecret: "", + SettingKeyGoogleOAuthRedirectURL: "", + SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend, SettingKeyOIDCConnectEnabled: "false", SettingKeyOIDCConnectProviderName: "OIDC", SettingKeyOIDCConnectClientID: "", @@ -1874,6 +2309,16 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false", SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultGitHubBalance: "0", + SettingKeyAuthSourceDefaultGitHubConcurrency: "5", + SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]", + SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false", + SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultGoogleBalance: "0", + SettingKeyAuthSourceDefaultGoogleConcurrency: "5", + SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]", + SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false", + SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false", SettingKeyForceEmailOnThirdPartySignup: "false", SettingKeySMTPPort: "587", SettingKeySMTPUseTLS: "false", @@ -1903,6 +2348,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // Affiliate (邀请返利) feature (default disabled; opt-in) SettingKeyAffiliateEnabled: "false", + // 风控中心功能(默认关闭,显式启用) + SettingKeyRiskControlEnabled: "false", + // Claude Code version check (default: empty = disabled) SettingKeyMinClaudeCodeVersion: "", SettingKeyMaxClaudeCodeVersion: "", @@ -1923,6 +2371,11 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // parseSettings 解析设置到结构体 func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" + loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments]) + loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt]) + if loginAgreementUpdatedAt == "" { + loginAgreementUpdatedAt = defaultLoginAgreementDate + } result := &SystemSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, @@ -1932,6 +2385,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin FrontendURL: settings[SettingKeyFrontendURL], InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true", + LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]), + LoginAgreementUpdatedAt: loginAgreementUpdatedAt, + LoginAgreementDocuments: loginAgreementDocuments, SMTPHost: settings[SettingKeySMTPHost], SMTPUsername: settings[SettingKeySMTPUsername], SMTPFrom: settings[SettingKeySMTPFrom], @@ -2173,6 +2630,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" + gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github") + result.GitHubOAuthEnabled = gitHubEffective.Enabled + result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID) + result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret) + result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != "" + result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL) + result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL) + + googleEffective := s.effectiveEmailOAuthConfig(settings, "google") + result.GoogleOAuthEnabled = googleEffective.Enabled + result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID) + result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret) + result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != "" + result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL) + result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL) + // WeChat Connect 设置: // - 优先读取 DB 系统设置 // - 缺失时回退到 config/env,保持升级兼容 @@ -2242,6 +2715,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Affiliate (邀请返利) feature (default: disabled; strict true) result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true" + // 风控中心功能(默认关闭,严格 true 才启用) + result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true" + // Claude Code version check result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion] diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index aaf837bd..80b8b32a 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -20,6 +20,10 @@ type SystemSettings struct { FrontendURL string InvitationCodeEnabled bool TotpEnabled bool // TOTP 双因素认证 + LoginAgreementEnabled bool + LoginAgreementMode string + LoginAgreementUpdatedAt string + LoginAgreementDocuments []LoginAgreementDocument SMTPHost string SMTPPort int @@ -89,6 +93,20 @@ type SystemSettings struct { OIDCConnectUserInfoIDPath string OIDCConnectUserInfoUsernamePath string + // GitHub / Google 邮箱快捷登录 + GitHubOAuthEnabled bool + GitHubOAuthClientID string + GitHubOAuthClientSecret string + GitHubOAuthClientSecretConfigured bool + GitHubOAuthRedirectURL string + GitHubOAuthFrontendRedirectURL string + GoogleOAuthEnabled bool + GoogleOAuthClientID string + GoogleOAuthClientSecret string + GoogleOAuthClientSecretConfigured bool + GoogleOAuthRedirectURL string + GoogleOAuthFrontendRedirectURL string + SiteName string SiteLogo string SiteSubtitle string @@ -106,6 +124,7 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + RiskControlEnabled bool AffiliateEnabled bool AffiliateRebateRate float64 AffiliateRebateFreezeHours int @@ -190,6 +209,11 @@ type PublicSettings struct { PasswordResetEnabled bool InvitationCodeEnabled bool TotpEnabled bool // TOTP 双因素认证 + LoginAgreementEnabled bool + LoginAgreementMode string + LoginAgreementUpdatedAt string + LoginAgreementRevision string + LoginAgreementDocuments []LoginAgreementDocument TurnstileEnabled bool TurnstileSiteKey string SiteName string @@ -217,6 +241,8 @@ type PublicSettings struct { PaymentEnabled bool OIDCOAuthEnabled bool OIDCOAuthProviderName string + GitHubOAuthEnabled bool + GoogleOAuthEnabled bool Version string BalanceLowNotifyEnabled bool @@ -233,6 +259,15 @@ type PublicSettings struct { // Affiliate (邀请返利) feature toggle AffiliateEnabled bool `json:"affiliate_enabled"` + + // 风控中心功能开关 + RiskControlEnabled bool `json:"risk_control_enabled"` +} + +type LoginAgreementDocument struct { + ID string `json:"id"` + Title string `json:"title"` + ContentMD string `json:"content_md"` } type WeChatConnectOAuthConfig struct { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 916c2267..823f9812 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -2,6 +2,7 @@ package service import ( "context" + "strings" "time" ) @@ -95,6 +96,9 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { // NeedsRefresh 检查token是否需要刷新 // expires_at 缺失且处于限流状态时需要刷新,防止限流期间 token 静默过期 func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { + if strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" { + return false + } expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { return account.IsRateLimited() diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 193c0430..208a05db 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -97,6 +97,8 @@ type UserRepository interface { UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error UpdateConcurrency(ctx context.Context, id int64, amount int) error + BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) + BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) ExistsByEmail(ctx context.Context, email string) (bool, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) // AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略) diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index ff55c2a5..775dd602 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } + +func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { out := make([]UserAuthIdentityRecord, len(m.identities)) diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 699307e4..f424fb00 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -515,6 +515,7 @@ var ProviderSet = wire.NewSet( NewGroupCapacityService, NewChannelService, NewModelPricingResolver, + NewContentModerationService, NewAffiliateService, ProvidePaymentConfigService, NewPaymentService, diff --git a/backend/migrations/135_allow_email_oauth_provider_types.sql b/backend/migrations/135_allow_email_oauth_provider_types.sql new file mode 100644 index 00000000..a04edd7c --- /dev/null +++ b/backend/migrations/135_allow_email_oauth_provider_types.sql @@ -0,0 +1,27 @@ +ALTER TABLE users + DROP CONSTRAINT IF EXISTS users_signup_source_check; + +ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google')); + +ALTER TABLE auth_identities + DROP CONSTRAINT IF EXISTS auth_identities_provider_type_check; + +ALTER TABLE auth_identities + ADD CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google')); + +ALTER TABLE auth_identity_channels + DROP CONSTRAINT IF EXISTS auth_identity_channels_provider_type_check; + +ALTER TABLE auth_identity_channels + ADD CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google')); + +ALTER TABLE pending_auth_sessions + DROP CONSTRAINT IF EXISTS pending_auth_sessions_provider_type_check; + +ALTER TABLE pending_auth_sessions + ADD CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google')); diff --git a/backend/migrations/135_content_moderation.sql b/backend/migrations/135_content_moderation.sql new file mode 100644 index 00000000..4873bbf2 --- /dev/null +++ b/backend/migrations/135_content_moderation.sql @@ -0,0 +1,45 @@ +-- 风控中心内容审计配置与记录 + +INSERT INTO settings (key, value, updated_at) +VALUES ('risk_control_enabled', 'false', NOW()) +ON CONFLICT (key) DO NOTHING; + +CREATE TABLE IF NOT EXISTS content_moderation_logs ( + id BIGSERIAL PRIMARY KEY, + request_id VARCHAR(128) NOT NULL DEFAULT '', + user_id BIGINT REFERENCES users(id) ON DELETE SET NULL, + user_email VARCHAR(255) NOT NULL DEFAULT '', + api_key_id BIGINT REFERENCES api_keys(id) ON DELETE SET NULL, + api_key_name VARCHAR(100) NOT NULL DEFAULT '', + group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL, + group_name VARCHAR(255) NOT NULL DEFAULT '', + endpoint VARCHAR(128) NOT NULL DEFAULT '', + provider VARCHAR(64) NOT NULL DEFAULT '', + model VARCHAR(255) NOT NULL DEFAULT '', + mode VARCHAR(32) NOT NULL DEFAULT '', + action VARCHAR(32) NOT NULL DEFAULT '', + flagged BOOLEAN NOT NULL DEFAULT FALSE, + highest_category VARCHAR(64) NOT NULL DEFAULT '', + highest_score DECIMAL(8, 6) NOT NULL DEFAULT 0, + category_scores JSONB NOT NULL DEFAULT '{}'::jsonb, + threshold_snapshot JSONB NOT NULL DEFAULT '{}'::jsonb, + input_excerpt TEXT NOT NULL DEFAULT '', + upstream_latency_ms INT, + error TEXT NOT NULL DEFAULT '', + violation_count INT NOT NULL DEFAULT 0, + auto_banned BOOLEAN NOT NULL DEFAULT FALSE, + email_sent BOOLEAN NOT NULL DEFAULT FALSE, + queue_delay_ms INT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS violation_count INT NOT NULL DEFAULT 0; +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS auto_banned BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS email_sent BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS queue_delay_ms INT; +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_created_at ON content_moderation_logs(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_group_created_at ON content_moderation_logs(group_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_flagged_created_at ON content_moderation_logs(flagged, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_user_created_at ON content_moderation_logs(user_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_api_key_created_at ON content_moderation_logs(api_key_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_endpoint_created_at ON content_moderation_logs(endpoint, created_at DESC); diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 99216296..d122a3b7 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -142,3 +142,16 @@ func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count") require.NotContains(t, sql, "detail::jsonb") } + +func TestMigration135AllowsGitHubAndGoogleAuthProviders(t *testing.T) { + content, err := FS.ReadFile("135_allow_email_oauth_provider_types.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "users_signup_source_check") + require.Contains(t, sql, "auth_identities_provider_type_check") + require.Contains(t, sql, "auth_identity_channels_provider_type_check") + require.Contains(t, sql, "pending_auth_sessions_provider_type_check") + require.Contains(t, sql, "'github'") + require.Contains(t, sql, "'google'") +} diff --git a/deploy/Dockerfile b/deploy/Dockerfile index b0b6036c..a947158f 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.26.2-alpine +ARG GOLANG_IMAGE=golang:1.26.3-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index dc5d3cc9..576fab0a 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -202,6 +202,14 @@ gateway: # # 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。 force_codex_cli: false + # Enable Codex image-generation bridge injection for /openai/v1/responses. + # 是否为 Codex /responses 请求自动注入 image_generation 工具与桥接指令。 + # + # Default false keeps text-only Codex requests text-only. Explicit client-provided + # image_generation tools are still forwarded when the group allows image generation. + # 默认 false:保持纯文本 Codex 请求不被改写;客户端显式提供 image_generation tool 时, + # 仍会在分组允许图片生成的情况下正常转发。 + codex_image_generation_bridge_enabled: false # Optional: template file used to build the final top-level Codex `instructions`. # 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。 # diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 8a127793..00ed4087 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -16,6 +16,8 @@ import type { TempUnschedulableStatus, AdminDataPayload, AdminDataImportResult, + CodexSessionImportRequest, + CodexSessionImportResult, CheckMixedChannelRequest, CheckMixedChannelResponse } from '@/types' @@ -547,6 +549,11 @@ export async function importData(payload: { return data } +export async function importCodexSession(payload: CodexSessionImportRequest): Promise { + const { data } = await apiClient.post('/admin/accounts/import/codex-session', payload) + return data +} + /** * Get Antigravity default model mapping from backend * @returns Default model mapping (from -> to) @@ -663,6 +670,7 @@ export const accountsAPI = { syncFromCrs, exportData, importData, + importCodexSession, getAntigravityDefaultModelMapping, batchClearError, batchRefresh, diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 639e3be2..b62e539d 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -31,6 +31,7 @@ import channelMonitorTemplateAPI from './channelMonitorTemplate' import adminPaymentAPI from './payment' import windsurfAPI from './windsurf' import affiliatesAPI from './affiliates' +import riskControlAPI from './riskControl' /** * Unified admin API object for convenient access @@ -63,7 +64,8 @@ export const adminAPI = { channelMonitorTemplate: channelMonitorTemplateAPI, payment: adminPaymentAPI, windsurf: windsurfAPI, - affiliates: affiliatesAPI + affiliates: affiliatesAPI, + riskControl: riskControlAPI } export { @@ -94,7 +96,8 @@ export { channelMonitorTemplateAPI, adminPaymentAPI, windsurfAPI, - affiliatesAPI + affiliatesAPI, + riskControlAPI } export default adminAPI @@ -104,3 +107,4 @@ export type { BalanceHistoryItem } from './users' export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' export type { BackupAgentHealth, DataManagementConfig } from './dataManagement' export type { TLSFingerprintProfile, CreateProfileRequest, UpdateProfileRequest } from './tlsFingerprintProfile' +export type { ContentModerationConfig, ContentModerationLog, ModerationMode } from './riskControl' diff --git a/frontend/src/api/admin/riskControl.ts b/frontend/src/api/admin/riskControl.ts new file mode 100644 index 00000000..e63a53a2 --- /dev/null +++ b/frontend/src/api/admin/riskControl.ts @@ -0,0 +1,253 @@ +import { apiClient } from '../client' + +export type ModerationMode = 'off' | 'observe' | 'pre_block' + +export interface ContentModerationConfig { + enabled: boolean + mode: ModerationMode + base_url: string + model: string + api_key_configured: boolean + api_key_masked: string + api_key_count: number + api_key_masks: string[] + api_key_statuses: ContentModerationAPIKeyStatus[] + timeout_ms: number + sample_rate: number + all_groups: boolean + group_ids: number[] + record_non_hits: boolean + worker_count: number + queue_size: number + block_status: number + block_message: string + email_on_hit: boolean + auto_ban_enabled: boolean + ban_threshold: number + violation_window_hours: number + retry_count: number + hit_retention_days: number + non_hit_retention_days: number + pre_hash_check_enabled: boolean +} + +export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen' + +export interface ContentModerationAPIKeyStatus { + index: number + key_hash: string + masked: string + status: ContentModerationAPIKeyStatusValue + failure_count: number + success_count: number + last_error: string + last_checked_at?: string + frozen_until?: string + last_latency_ms: number + last_http_status: number + last_tested: boolean + configured: boolean +} + +export interface TestContentModerationAPIKeysPayload { + api_keys?: string[] + base_url?: string + model?: string + timeout_ms?: number + prompt?: string + images?: string[] +} + +export interface TestContentModerationAPIKeysResponse { + items: ContentModerationAPIKeyStatus[] + audit_result?: ContentModerationTestAuditResult + image_count: number +} + +export interface ContentModerationTestAuditResult { + flagged: boolean + highest_category: string + highest_score: number + composite_score: number + category_scores: Record + thresholds: Record +} + +export interface UpdateContentModerationConfig { + enabled?: boolean + mode?: ModerationMode + base_url?: string + model?: string + api_key?: string + api_keys?: string[] + api_keys_mode?: 'append' | 'replace' + delete_api_key_hashes?: string[] + clear_api_key?: boolean + timeout_ms?: number + sample_rate?: number + all_groups?: boolean + group_ids?: number[] + record_non_hits?: boolean + worker_count?: number + queue_size?: number + block_status?: number + block_message?: string + email_on_hit?: boolean + auto_ban_enabled?: boolean + ban_threshold?: number + violation_window_hours?: number + retry_count?: number + hit_retention_days?: number + non_hit_retention_days?: number + pre_hash_check_enabled?: boolean +} + +export interface ContentModerationRuntimeStatus { + enabled: boolean + risk_control_enabled: boolean + mode: ModerationMode + worker_count: number + max_workers: number + active_workers: number + idle_workers: number + queue_size: number + queue_length: number + queue_usage_percent: number + enqueued: number + dropped: number + processed: number + errors: number + api_key_statuses: ContentModerationAPIKeyStatus[] + flagged_hash_count: number + last_cleanup_at?: string + last_cleanup_deleted_hit: number + last_cleanup_deleted_non_hit: number +} + +export interface ContentModerationLog { + id: number + request_id: string + user_id: number | null + user_email: string + api_key_id: number | null + api_key_name: string + group_id: number | null + group_name: string + endpoint: string + provider: string + model: string + mode: string + action: string + flagged: boolean + highest_category: string + highest_score: number + category_scores: Record + threshold_snapshot: Record + input_excerpt: string + upstream_latency_ms: number | null + error: string + violation_count: number + auto_banned: boolean + email_sent: boolean + user_status: string + queue_delay_ms: number | null + created_at: string +} + +export interface ListContentModerationLogsParams { + page?: number + page_size?: number + result?: string + group_id?: number + endpoint?: string + search?: string + from?: string + to?: string +} + +export interface ContentModerationLogsResponse { + items: ContentModerationLog[] + total: number + page: number + page_size: number + pages: number +} + +export interface ContentModerationUnbanUserResponse { + user_id: number + status: string +} + +export interface DeleteFlaggedHashResponse { + input_hash: string + deleted: boolean +} + +export interface ClearFlaggedHashesResponse { + deleted: number +} + +export async function getConfig(): Promise { + const { data } = await apiClient.get('/admin/risk-control/config') + return data +} + +export async function updateConfig( + payload: UpdateContentModerationConfig +): Promise { + const { data } = await apiClient.put('/admin/risk-control/config', payload) + return data +} + +export async function getStatus(): Promise { + const { data } = await apiClient.get('/admin/risk-control/status') + return data +} + +export async function testAPIKeys( + payload: TestContentModerationAPIKeysPayload = {} +): Promise { + const { data } = await apiClient.post('/admin/risk-control/api-keys/test', payload) + return data +} + +export async function listLogs( + params: ListContentModerationLogsParams = {} +): Promise { + const { data } = await apiClient.get('/admin/risk-control/logs', { + params, + }) + return data +} + +export async function unbanUser(userID: number): Promise { + const { data } = await apiClient.post( + `/admin/risk-control/users/${userID}/unban` + ) + return data +} + +export async function deleteFlaggedHash(inputHash: string): Promise { + const { data } = await apiClient.delete('/admin/risk-control/hashes', { + data: { input_hash: inputHash }, + }) + return data +} + +export async function clearFlaggedHashes(): Promise { + const { data } = await apiClient.delete('/admin/risk-control/hashes/all') + return data +} + +export const riskControlAPI = { + getConfig, + updateConfig, + getStatus, + testAPIKeys, + listLogs, + unbanUser, + deleteFlaggedHash, + clearFlaggedHashes, +} + +export default riskControlAPI diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 057a85e8..a863111d 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -4,14 +4,25 @@ */ import { apiClient } from "../client"; -import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types"; +import type { + CustomEndpoint, + CustomMenuItem, + LoginAgreementDocument, + NotifyEmailEntry, +} from "@/types"; export interface DefaultSubscriptionSetting { group_id: number; validity_days: number; } -export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat"; +export type AuthSourceType = + | "email" + | "linuxdo" + | "oidc" + | "wechat" + | "github" + | "google"; export interface AuthSourceDefaultsValue { balance: number; @@ -51,6 +62,8 @@ const AUTH_SOURCE_TYPES: AuthSourceType[] = [ "linuxdo", "oidc", "wechat", + "github", + "google", ]; const AUTH_SOURCE_DEFAULT_BALANCE = 0; const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5; @@ -306,6 +319,10 @@ export interface SystemSettings { invitation_code_enabled: boolean; totp_enabled: boolean; // TOTP 双因素认证 totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 + login_agreement_enabled: boolean; + login_agreement_mode: "modal" | "checkbox" | string; + login_agreement_updated_at: string; + login_agreement_documents: LoginAgreementDocument[]; // Default settings default_balance: number; affiliate_rebate_rate: number; @@ -335,6 +352,16 @@ export interface SystemSettings { auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; auth_source_default_wechat_grant_on_signup?: boolean; auth_source_default_wechat_grant_on_first_bind?: boolean; + auth_source_default_github_balance?: number; + auth_source_default_github_concurrency?: number; + auth_source_default_github_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_github_grant_on_signup?: boolean; + auth_source_default_github_grant_on_first_bind?: boolean; + auth_source_default_google_balance?: number; + auth_source_default_google_concurrency?: number; + auth_source_default_google_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_google_grant_on_signup?: boolean; + auth_source_default_google_grant_on_first_bind?: boolean; force_email_on_third_party_signup?: boolean; // OEM settings site_name: string; @@ -410,6 +437,16 @@ export interface SystemSettings { oidc_connect_userinfo_email_path: string; oidc_connect_userinfo_id_path: string; oidc_connect_userinfo_username_path: string; + github_oauth_enabled: boolean; + github_oauth_client_id: string; + github_oauth_client_secret_configured: boolean; + github_oauth_redirect_url: string; + github_oauth_frontend_redirect_url: string; + google_oauth_enabled: boolean; + google_oauth_client_id: string; + google_oauth_client_secret_configured: boolean; + google_oauth_redirect_url: string; + google_oauth_frontend_redirect_url: string; // Model fallback configuration enable_model_fallback: boolean; @@ -445,6 +482,7 @@ export interface SystemSettings { // Payment configuration payment_enabled: boolean; + risk_control_enabled: boolean; payment_min_amount: number; payment_max_amount: number; payment_daily_limit: number; @@ -500,6 +538,10 @@ export interface UpdateSettingsRequest { frontend_url?: string; invitation_code_enabled?: boolean; totp_enabled?: boolean; // TOTP 双因素认证 + login_agreement_enabled?: boolean; + login_agreement_mode?: "modal" | "checkbox" | string; + login_agreement_updated_at?: string; + login_agreement_documents?: LoginAgreementDocument[]; default_balance?: number; affiliate_rebate_rate?: number; affiliate_rebate_freeze_hours?: number; @@ -528,6 +570,16 @@ export interface UpdateSettingsRequest { auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; auth_source_default_wechat_grant_on_signup?: boolean; auth_source_default_wechat_grant_on_first_bind?: boolean; + auth_source_default_github_balance?: number; + auth_source_default_github_concurrency?: number; + auth_source_default_github_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_github_grant_on_signup?: boolean; + auth_source_default_github_grant_on_first_bind?: boolean; + auth_source_default_google_balance?: number; + auth_source_default_google_concurrency?: number; + auth_source_default_google_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_google_grant_on_signup?: boolean; + auth_source_default_google_grant_on_first_bind?: boolean; force_email_on_third_party_signup?: boolean; site_name?: string; site_logo?: string; @@ -594,6 +646,16 @@ export interface UpdateSettingsRequest { oidc_connect_userinfo_email_path?: string; oidc_connect_userinfo_id_path?: string; oidc_connect_userinfo_username_path?: string; + github_oauth_enabled?: boolean; + github_oauth_client_id?: string; + github_oauth_client_secret?: string; + github_oauth_redirect_url?: string; + github_oauth_frontend_redirect_url?: string; + google_oauth_enabled?: boolean; + google_oauth_client_id?: string; + google_oauth_client_secret?: string; + google_oauth_redirect_url?: string; + google_oauth_frontend_redirect_url?: string; enable_model_fallback?: boolean; fallback_model_anthropic?: string; fallback_model_openai?: string; @@ -615,6 +677,7 @@ export interface UpdateSettingsRequest { enable_anthropic_cache_ttl_1h_injection?: boolean; // Payment configuration payment_enabled?: boolean; + risk_control_enabled?: boolean; payment_min_amount?: number; payment_max_amount?: number; payment_daily_limit?: number; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index ca4cf3a7..29c627f2 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2908,6 +2908,7 @@ :show-mobile-refresh-token-option="form.platform === 'openai'" :show-session-token-option="false" :show-access-token-option="false" + :show-codex-session-import-option="form.platform === 'openai'" :platform="form.platform" :show-project-id="geminiOAuthType === 'code_assist'" @generate-url="handleGenerateUrl" @@ -2915,6 +2916,7 @@ @validate-refresh-token="handleValidateRefreshToken" @validate-mobile-refresh-token="handleOpenAIValidateMobileRT" @validate-session-token="handleValidateSessionToken" + @import-codex-session="handleOpenAIImportCodexSession" /> @@ -3266,6 +3268,7 @@ import type { AccountType, CheckMixedChannelResponse, CreateAccountRequest, + CodexSessionImportMessage, OpenAICompactMode } from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' @@ -3300,6 +3303,7 @@ interface OAuthFlowExposed { sessionKey: string refreshToken: string sessionToken: string + codexSession: string inputMethod: AuthInputMethod reset: () => void } @@ -4857,6 +4861,113 @@ const handleOpenAIExchange = async (authCode: string) => { // OpenAI Mobile RT client_id const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK' +const buildOpenAICodexImportCredentialExtras = (): Record | null => { + const credentials: Record = {} + if (!isOpenAIModelRestrictionDisabled.value) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + } + + const compactModelMapping = buildOpenAICompactModelMapping() + if (compactModelMapping) { + credentials.compact_model_mapping = compactModelMapping + } + + if (!applyTempUnschedConfig(credentials)) { + return null + } + return credentials +} + +const formatCodexImportMessages = (messages?: CodexSessionImportMessage[]) => { + return (messages || []) + .map((item) => { + const name = item.name ? ` ${item.name}` : '' + return `#${item.index}${name}: ${item.message}` + }) + .join('\n') +} + +const handleOpenAIImportCodexSession = async (content: string) => { + const oauthClient = openaiOAuth + const trimmed = content.trim() + if (!trimmed) { + oauthClient.error.value = t('admin.accounts.oauth.openai.codexSessionEmpty') + return + } + + const credentialExtras = buildOpenAICodexImportCredentialExtras() + if (credentialExtras === null) { + return + } + + oauthClient.loading.value = true + oauthClient.error.value = '' + + try { + const extra = buildOpenAIExtra() + const result = await adminAPI.accounts.importCodexSession({ + content: trimmed, + name: form.name, + notes: form.notes || null, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value, + credential_extras: Object.keys(credentialExtras).length > 0 ? credentialExtras : undefined, + extra, + update_existing: true + }) + + const successCount = result.created + result.updated + const params = { + created: result.created, + updated: result.updated, + skipped: result.skipped, + failed: result.failed + } + + if (successCount > 0 && result.failed === 0) { + appStore.showSuccess(t('admin.accounts.oauth.openai.codexSessionImportSuccess', params)) + emit('created') + handleClose() + return + } + + const errorText = formatCodexImportMessages(result.errors) + const warningText = formatCodexImportMessages(result.warnings) + oauthClient.error.value = [errorText, warningText].filter(Boolean).join('\n') + + if (result.failed === 0) { + appStore.showWarning(t('admin.accounts.oauth.openai.codexSessionImportSuccess', params)) + return + } + + if (successCount > 0) { + appStore.showWarning(t('admin.accounts.oauth.openai.codexSessionImportPartial', params)) + emit('created') + return + } + + appStore.showError(t('admin.accounts.oauth.openai.codexSessionImportFailed')) + } catch (error: any) { + oauthClient.error.value = + error.response?.data?.detail || + error.response?.data?.message || + error.message || + t('admin.accounts.oauth.openai.codexSessionImportFailed') + appStore.showError(oauthClient.error.value) + } finally { + oauthClient.loading.value = false + } +} + // OpenAI RT 批量验证和创建(共享逻辑) const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => { const oauthClient = openaiOAuth diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index b117bff3..3fd82f31 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1349,6 +1349,66 @@ + +
+
+
+
+ +
+
+
+ + + {{ codexImageGenerationBridgeBadgeLabel }} + +
+

+ {{ t('admin.accounts.openai.codexImageGenerationBridgeDesc') }} +

+
+
+
+
+ +
+
+
+
+
('auto') const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) +type CodexImageGenerationBridgeMode = 'inherit' | 'enabled' | 'disabled' +const codexImageGenerationBridgeMode = ref('inherit') const anthropicPassthroughEnabled = ref(false) const webSearchEmulationMode = ref('default') const webSearchGlobalEnabled = ref(false) @@ -2358,6 +2420,47 @@ const openaiResponsesWebSocketV2Mode = computed({ const openAIWSModeConcurrencyHintKey = computed(() => resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) ) +const codexImageGenerationBridgeOptions = computed>(() => [ + { + value: 'inherit', + label: t('admin.accounts.openai.codexImageGenerationBridgeInherit'), + description: t('admin.accounts.openai.codexImageGenerationBridgeInheritDesc') + }, + { + value: 'enabled', + label: t('admin.accounts.openai.codexImageGenerationBridgeEnabled'), + description: t('admin.accounts.openai.codexImageGenerationBridgeEnabledDesc') + }, + { + value: 'disabled', + label: t('admin.accounts.openai.codexImageGenerationBridgeDisabled'), + description: t('admin.accounts.openai.codexImageGenerationBridgeDisabledDesc') + } +]) +const codexImageGenerationBridgeBadgeLabel = computed(() => { + switch (codexImageGenerationBridgeMode.value) { + case 'enabled': + return t('admin.accounts.openai.codexImageGenerationBridgeBadgeEnabled') + case 'disabled': + return t('admin.accounts.openai.codexImageGenerationBridgeBadgeDisabled') + default: + return t('admin.accounts.openai.codexImageGenerationBridgeBadgeInherit') + } +}) +const codexImageGenerationBridgeBadgeClass = computed(() => { + switch (codexImageGenerationBridgeMode.value) { + case 'enabled': + return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/40 dark:text-emerald-300' + case 'disabled': + return 'bg-rose-100 text-rose-700 dark:bg-rose-900/40 dark:text-rose-300' + default: + return 'bg-slate-100 text-slate-600 dark:bg-dark-600 dark:text-slate-300' + } +}) const openAICompactModeOptions = computed(() => [ { value: 'auto', label: t('admin.accounts.openai.compactModeAuto') }, { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') }, @@ -2377,7 +2480,7 @@ const openAICompactStatusKey = computed(() => { ? 'admin.accounts.openai.compactSupported' : 'admin.accounts.openai.compactUnsupported' } - return 'admin.accounts.openai.compactUnknown' + return 'admin.accounts.openai.compactAuto' }) // Computed: current preset mappings based on platform @@ -2516,11 +2619,20 @@ const syncFormFromAccount = (newAccount: Account | null) => { openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false + codexImageGenerationBridgeMode.value = 'inherit' anthropicPassthroughEnabled.value = false webSearchEmulationMode.value = 'default' if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto' + const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean' + ? extra.codex_image_generation_bridge + : extra?.codex_image_generation_bridge_enabled + if (codexImageGenerationBridgeValue === true) { + codexImageGenerationBridgeMode.value = 'enabled' + } else if (codexImageGenerationBridgeValue === false) { + codexImageGenerationBridgeMode.value = 'disabled' + } openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { modeKey: 'openai_oauth_responses_websockets_v2_mode', enabledKey: 'openai_oauth_responses_websockets_v2_enabled', @@ -3652,6 +3764,13 @@ const handleSubmit = async () => { newExtra.openai_compact_mode = openAICompactMode.value } + delete newExtra.codex_image_generation_bridge_enabled + if (codexImageGenerationBridgeMode.value === 'inherit') { + delete newExtra.codex_image_generation_bridge + } else { + newExtra.codex_image_generation_bridge = codexImageGenerationBridgeMode.value === 'enabled' + } + if (props.account.type === 'oauth') { if (codexCLIOnlyEnabled.value) { newExtra.codex_cli_only = true diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 08c67494..9526e878 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -81,6 +81,17 @@ t('admin.accounts.oauth.openai.accessTokenAuth', '手动输入 AT') }} +
@@ -168,6 +179,85 @@ + +
+
+

+ {{ t('admin.accounts.oauth.openai.codexSessionDesc') }} +

+ +
+ + +

+ {{ t('admin.accounts.oauth.openai.codexSessionHint') }} +

+
+ +
+

+ {{ error }} +

+
+ + +
+
+
(), { showMobileRefreshTokenOption: false, showSessionTokenOption: false, showAccessTokenOption: false, + showCodexSessionImportOption: false, platform: 'anthropic', showProjectId: true }) @@ -591,6 +683,7 @@ const emit = defineEmits<{ 'validate-mobile-refresh-token': [refreshToken: string] 'validate-session-token': [sessionToken: string] 'import-access-token': [accessToken: string] + 'import-codex-session': [content: string] 'update:inputMethod': [method: AuthInputMethod] }>() @@ -630,12 +723,13 @@ const authCodeInput = ref('') const sessionKeyInput = ref('') const refreshTokenInput = ref('') const sessionTokenInput = ref('') +const codexSessionInput = ref('') const showHelpDialog = ref(false) const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption || props.showCodexSessionImportOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -656,6 +750,16 @@ const parsedRefreshTokenCount = computed(() => { .filter((rt) => rt).length }) +const parsedCodexSessionCount = computed(() => { + const trimmed = codexSessionInput.value.trim() + if (!trimmed) return 0 + if (trimmed.startsWith('{') || trimmed.startsWith('[')) return 1 + return trimmed + .split('\n') + .map((item) => item.trim()) + .filter((item) => item).length +}) + // Watchers watch(inputMethod, (newVal) => { emit('update:inputMethod', newVal) @@ -727,6 +831,12 @@ const handleValidateRefreshToken = () => { } } +const handleImportCodexSession = () => { + if (codexSessionInput.value.trim()) { + emit('import-codex-session', codexSessionInput.value.trim()) + } +} + // Expose methods and state defineExpose({ authCode: authCodeInput, @@ -735,6 +845,7 @@ defineExpose({ sessionKey: sessionKeyInput, refreshToken: refreshTokenInput, sessionToken: sessionTokenInput, + codexSession: codexSessionInput, inputMethod, reset: () => { authCodeInput.value = '' @@ -743,6 +854,7 @@ defineExpose({ sessionKeyInput.value = '' refreshTokenInput.value = '' sessionTokenInput.value = '' + codexSessionInput.value = '' inputMethod.value = 'manual' showHelpDialog.value = false } diff --git a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts index c4e2a9bc..04486154 100644 --- a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts +++ b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts @@ -216,4 +216,25 @@ describe('EditAccountModal', () => { 'gpt-5.4': 'gpt-5.4-openai-compact' }) }) + + it('submits account-level Codex image generation bridge override', async () => { + const account = buildAccount() + account.extra = { + codex_image_generation_bridge: false, + codex_image_generation_bridge_enabled: true + } + updateAccountMock.mockReset() + checkMixedChannelRiskMock.mockReset() + checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false }) + updateAccountMock.mockResolvedValue(account) + + const wrapper = mountModal(account) + + await wrapper.get('button[data-testid="codex-image-bridge-enabled"]').trigger('click') + await wrapper.get('form#edit-account-form').trigger('submit.prevent') + + expect(updateAccountMock).toHaveBeenCalledTimes(1) + expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.codex_image_generation_bridge).toBe(true) + expect(updateAccountMock.mock.calls[0]?.[1]?.extra).not.toHaveProperty('codex_image_generation_bridge_enabled') + }) }) diff --git a/frontend/src/components/admin/account/AccountTableActions.vue b/frontend/src/components/admin/account/AccountTableActions.vue index ee521f83..6874625b 100644 --- a/frontend/src/components/admin/account/AccountTableActions.vue +++ b/frontend/src/components/admin/account/AccountTableActions.vue @@ -5,7 +5,6 @@ - @@ -17,7 +16,7 @@ import { useI18n } from 'vue-i18n' import Icon from '@/components/icons/Icon.vue' defineProps(['loading']) -defineEmits(['refresh', 'sync', 'create']) +defineEmits(['refresh', 'create']) const { t } = useI18n() diff --git a/frontend/src/components/auth/EmailOAuthButtons.vue b/frontend/src/components/auth/EmailOAuthButtons.vue new file mode 100644 index 00000000..b5d874a5 --- /dev/null +++ b/frontend/src/components/auth/EmailOAuthButtons.vue @@ -0,0 +1,87 @@ + + + diff --git a/frontend/src/components/auth/GitHubMark.vue b/frontend/src/components/auth/GitHubMark.vue new file mode 100644 index 00000000..a790e622 --- /dev/null +++ b/frontend/src/components/auth/GitHubMark.vue @@ -0,0 +1,7 @@ + diff --git a/frontend/src/components/auth/GoogleMark.vue b/frontend/src/components/auth/GoogleMark.vue new file mode 100644 index 00000000..a848a811 --- /dev/null +++ b/frontend/src/components/auth/GoogleMark.vue @@ -0,0 +1,8 @@ + diff --git a/frontend/src/components/auth/LoginAgreementPrompt.vue b/frontend/src/components/auth/LoginAgreementPrompt.vue new file mode 100644 index 00000000..dd71cbdc --- /dev/null +++ b/frontend/src/components/auth/LoginAgreementPrompt.vue @@ -0,0 +1,221 @@ + + + + + diff --git a/frontend/src/components/auth/__tests__/EmailOAuthButtons.spec.ts b/frontend/src/components/auth/__tests__/EmailOAuthButtons.spec.ts new file mode 100644 index 00000000..d8517808 --- /dev/null +++ b/frontend/src/components/auth/__tests__/EmailOAuthButtons.spec.ts @@ -0,0 +1,103 @@ +import { mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import EmailOAuthButtons from '@/components/auth/EmailOAuthButtons.vue' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/register?aff=AFF123' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.emailOAuth.signIn') { + return `使用 ${params?.providerName ?? ''} 登录` + } + return key + }, + }), +})) + +describe('EmailOAuthButtons', () => { + beforeEach(() => { + routeState.query = { redirect: '/billing?plan=pro', aff: 'AFF123' } + locationState.current = { href: 'http://localhost/register?aff=AFF123' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + window.localStorage.clear() + window.sessionStorage.clear() + }) + + it('passes the affiliate code to the email oauth start URL', async () => { + const wrapper = mount(EmailOAuthButtons, { + props: { + githubEnabled: true, + googleEnabled: false, + }, + global: { + stubs: { + GitHubMark: true, + GoogleMark: true, + }, + }, + }) + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toBe( + '/api/v1/auth/oauth/github/start?redirect=%2Fbilling%3Fplan%3Dpro&aff_code=AFF123' + ) + expect(window.sessionStorage.getItem('oauth_aff_code')).toBe('AFF123') + expect(window.sessionStorage.getItem('email_oauth_pending_provider')).toBe('github') + }) + + it('uses a full-width descriptive button when only GitHub is enabled', () => { + const wrapper = mount(EmailOAuthButtons, { + props: { + githubEnabled: true, + googleEnabled: false, + }, + global: { + stubs: { + GitHubMark: true, + GoogleMark: true, + }, + }, + }) + + expect(wrapper.find('.grid').classes()).not.toContain('sm:grid-cols-2') + expect(wrapper.get('button').text()).toContain('使用 GitHub 登录') + }) + + it('uses compact labels and two columns when GitHub and Google are both enabled', () => { + const wrapper = mount(EmailOAuthButtons, { + props: { + githubEnabled: true, + googleEnabled: true, + }, + global: { + stubs: { + GitHubMark: true, + GoogleMark: true, + }, + }, + }) + + expect(wrapper.find('.grid').classes()).toContain('sm:grid-cols-2') + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(2) + expect(buttons[0].text()).toContain('GitHub') + expect(buttons[0].text()).not.toContain('使用 GitHub 登录') + expect(buttons[1].text()).toContain('Google') + expect(buttons[1].text()).not.toContain('使用 Google 登录') + }) +}) diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 47d8a25f..bede24e9 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -593,6 +593,21 @@ const SignalIcon = { ) } +const ShieldIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M9 12.75L11.25 15 15 9.75m-3-7.036A11.959 11.959 0 013.598 6 11.99 11.99 0 003 9.749c0 5.592 3.824 10.29 9 11.623 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751h-.152c-3.196 0-6.1-1.248-8.25-3.285z' + }) + ] + ) +} + const PriceTagIcon = { render: () => h( @@ -635,6 +650,7 @@ const flagChannelMonitor = makeSidebarFlag(FeatureFlags.channelMonitor) const flagPayment = makeSidebarFlag(FeatureFlags.payment) const flagAvailableChannels = makeSidebarFlag(FeatureFlags.availableChannels) const flagAffiliate = makeSidebarFlag(FeatureFlags.affiliate) +const flagRiskControl = makeSidebarFlag(FeatureFlags.riskControl) const flagOpsMonitoring = () => adminSettingsStore.opsMonitoringEnabled const flagAdminPayment = () => adminSettingsStore.paymentEnabled @@ -720,6 +736,7 @@ const adminNavItems = computed((): NavItem[] => { { path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon }, { path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon }, { path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon }, + { path: '/admin/risk-control', label: t('nav.riskControl'), icon: ShieldIcon, hideInSimpleMode: true, featureFlag: flagRiskControl }, { path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true }, { path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true }, { diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index 37ee8a55..2c190715 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -263,7 +263,9 @@ const providerLabels = computed>(() => ({ email: t('profile.authBindings.providers.email'), linuxdo: t('profile.authBindings.providers.linuxdo'), oidc: t('profile.authBindings.providers.oidc', { providerName: props.oidcProviderName }), - wechat: t('profile.authBindings.providers.wechat') + wechat: t('profile.authBindings.providers.wechat'), + github: 'GitHub', + google: 'Google' })) function formatCurrency(value: number): string { @@ -272,7 +274,13 @@ function formatCurrency(value: number): string { function normalizeProvider(value: string): UserAuthProvider | null { const normalized = value.trim().toLowerCase() - if (normalized === 'email' || normalized === 'linuxdo' || normalized === 'wechat') { + if ( + normalized === 'email' || + normalized === 'linuxdo' || + normalized === 'wechat' || + normalized === 'github' || + normalized === 'google' + ) { return normalized } if (normalized === 'oidc' || normalized.startsWith('oidc:') || normalized.startsWith('oidc/')) { diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts index d35e3b12..29ec513e 100644 --- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts +++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts @@ -13,6 +13,7 @@ describe('useModelWhitelist', () => { expect(models).toContain('gpt-5.4') expect(models).toContain('gpt-5.4-mini') expect(models).toContain('gpt-5.4-2026-03-05') + expect(models).toContain('codex-auto-review') }) it('openai 模型列表不再暴露已下线的 ChatGPT 登录 Codex 模型', () => { diff --git a/frontend/src/composables/useAccountOAuth.ts b/frontend/src/composables/useAccountOAuth.ts index 564e7d95..ab4c640a 100644 --- a/frontend/src/composables/useAccountOAuth.ts +++ b/frontend/src/composables/useAccountOAuth.ts @@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' export type AddMethod = 'oauth' | 'setup-token' -export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'mobile_refresh_token' | 'session_token' | 'access_token' +export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'mobile_refresh_token' | 'session_token' | 'access_token' | 'codex_session' export interface OAuthState { authUrl: string diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 07c98ed0..7b474a4e 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -11,8 +11,8 @@ const openaiModels = [ 'gpt-5.5', // GPT-5.4 系列 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05', - // GPT-5.3 系列 - 'gpt-5.3-codex', 'gpt-5.3-codex-spark', + // GPT-5.3 / Codex 系列 + 'gpt-5.3-codex', 'gpt-5.3-codex-spark', 'codex-auto-review', 'gpt-4o-audio-preview', 'gpt-4o-realtime-preview', // GPT Image 系列 'gpt-image-1', 'gpt-image-1.5', 'gpt-image-2' diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 2db497f2..02a4974f 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -384,6 +384,7 @@ export default { channelPricing: 'Channel Pricing', channelMonitor: 'Channel Monitor', channelStatus: 'Channel Status', + riskControl: 'Risk Control', }, // Auth @@ -412,6 +413,9 @@ export default { passwordRequired: 'Password is required', passwordMinLength: 'Password must be at least 6 characters', loginFailed: 'Login failed. Please check your credentials and try again.', + errors: { + USER_NOT_ACTIVE: 'Account has been disabled.', + }, registrationFailed: 'Registration failed. Please try again.', emailSuffixNotAllowed: 'This email domain is not allowed for registration.', emailSuffixNotAllowedWithAllowed: @@ -474,6 +478,9 @@ export default { completing: 'Completing registration…', completeRegistrationFailed: 'Registration failed. Please check your invitation code and try again.' }, + emailOAuth: { + signIn: 'Continue with {providerName}' + }, oidc: { signIn: 'Continue with {providerName}', callbackTitle: 'Signing you in with {providerName}', @@ -533,6 +540,8 @@ export default { oauth: { callbackTitle: 'OAuth Callback', callbackHint: 'Copy the code and state back to the admin authorization flow when needed.', + invalidCallbackTitle: 'Invalid sign-in callback', + invalidCallbackHint: 'This page does not contain a valid authorization result. Return to the login page and start quick sign-in again.', code: 'Code', state: 'State', fullUrl: 'Full URL' @@ -2286,6 +2295,8 @@ export default { webSearchEmulation: 'Web Search Emulation', webSearchEmulationHint: '⚠️ When enabled, all accounts in this channel\'s Anthropic groups will intercept web_search requests. Use with caution.', webSearchEmulationGlobalDisabled: 'Please enable the global switch first in Settings → Gateway → Web Search Emulation', + codexImageGenerationBridge: 'Codex Image Generation Bridge', + codexImageGenerationBridgeHint: 'When enabled, Codex /responses text requests in OpenAI groups may be automatically given the image_generation tool. Keep off unless the routed accounts support image generation.', basicSettings: 'Basic Settings', addPlatform: 'Add Platform', noPlatforms: 'Click "Add Platform" to start configuring the channel', @@ -2308,6 +2319,216 @@ export default { } }, + riskControl: { + title: 'Risk Control', + description: 'Configure content moderation and review audit records', + loadFailed: 'Failed to load risk control', + saveFailed: 'Failed to save content moderation config', + logsFailed: 'Failed to load audit records', + saved: 'Content moderation config saved', + refresh: 'Refresh', + config: 'Content Moderation Config', + configHint: 'Use OpenAI Moderations to score request content and handle threshold hits by mode.', + openSettings: 'Moderation Settings', + settingsTitle: 'Content Moderation Settings', + refreshStatus: 'Refresh Status', + records: 'Audit Records', + recordsHint: 'Shows hits, blocks, errors, and sampled records.', + saveConfig: 'Save Moderation Config', + statusFailed: 'Failed to load runtime status', + enabled: 'Enable Content Moderation', + enabledHint: 'When off, gateway requests are not moderated even if the menu is enabled.', + mode: 'Global Mode', + modePreBlock: 'Pre-Block', + modePreBlockDesc: 'Synchronously reviews the latest user input before every request and rejects hits immediately.', + modeObserve: 'Observe Only', + modeObserveDesc: 'Requests pass through while the latest user input is queued for async review; hits are recorded, notified, and counted.', + modeOff: 'Off', + modeOffDesc: 'Content moderation is disabled and no audit records are written.', + baseUrl: 'OpenAI Base URL', + model: 'Model', + apiKey: 'OpenAI API Key', + apiKeys: 'OpenAI API Keys', + apiKeyCount: '{count} keys', + apiKeyPlaceholder: 'Enter API Key', + apiKeysPlaceholder: 'Add API Keys, one per line. They will be appended on save.', + apiKeysPlaceholderReplace: 'Replace API Keys, one per line. Stored keys will be replaced on save.', + apiKeysPlaceholderKeep: 'Add API Keys, one per line. They will be appended on save.', + apiKeysHint: '{count} keys are currently stored. This input only adds keys; save appends and de-duplicates them.', + apiKeysWriteMode: 'Write mode', + apiKeysModeAppend: 'Add', + apiKeysModeReplace: 'Replace', + apiKeysModeAppendHint: 'Default: save appends input keys and keeps stored keys.', + apiKeysModeReplaceHint: 'Replace mode: save replaces all stored keys with input keys.', + apiKeysReplaceWarning: 'Replace mode', + apiKeysReplaceNoInput: 'Replace mode requires at least 1 API Key', + apiKeyPlaceholderKeep: 'Leave empty to keep current key', + apiKeyWillClear: 'Configured key will be cleared on save', + apiKeyConfigured: 'Configured', + apiKeyTemporary: 'Pending', + apiKeyPendingDelete: 'Pending delete', + apiKeyPendingDeleteCount: '{count} keys pending deletion', + deleteApiKey: 'Delete this key', + undoDeleteApiKey: 'Undo delete', + inputApiKeyCount: '{count} keys in input', + storedApiKeyCount: '{count} stored keys', + testInputApiKeys: 'Test input keys', + testStoredApiKeys: 'Test stored keys', + testContentWithStoredApiKey: 'Test content with stored key', + testingApiKeys: 'Testing', + apiKeyTestNoInput: 'Enter OpenAI API Keys to test first', + apiKeyTestDone: 'Key test completed for {count} keys', + apiKeyTestFailed: 'Failed to test OpenAI API Keys', + apiKeyHealth: 'Key Availability', + apiKeyFreezeRule: '400 does not freeze; 401/403 freeze for 10 minutes; 429/529 freeze for 1 minute; other HTTP errors freeze for 10 seconds.', + apiKeyRows: '{count} keys', + apiKeyRowsCollapsed: '{count} keys hidden', + apiKeyRowsExpanded: 'Showing all {count} keys', + expandApiKeyRows: 'Expand', + collapseApiKeyRows: 'Collapse', + apiKeyHealthEmpty: 'No key status yet', + apiKeyHealthEmptyHint: 'Save keys or test input keys to see availability.', + apiKeyStatusOk: 'Available', + apiKeyStatusError: 'Error', + apiKeyStatusFrozen: 'Frozen', + apiKeyStatusUnknown: 'Untested', + apiKeyFailureCount: '{count} failures', + apiKeyLatency: '{ms} ms', + apiKeyHTTPStatus: 'HTTP {status}', + apiKeyFrozenUntil: 'Frozen until {time}', + apiKeyLastChecked: 'Checked at {time}', + apiKeyNotTested: 'Not tested', + auditTestInput: 'Audit Test Input', + auditTestInputHint: 'Enter a prompt and upload or paste images; images are sent as base64 and are not stored.', + auditTestPromptPlaceholder: 'Enter a user prompt to test; leave empty to only test key availability.', + auditTestImages: 'Test Images', + auditTestImagesHint: 'Upload, drag, or paste images. Up to 1 image, 8MB each.', + addAuditTestImage: 'Add image', + clearAuditTest: 'Clear test', + auditTestImageLimit: 'You can add up to {count} test images', + auditTestImageTooLarge: 'Each test image must be 8MB or smaller', + auditTestImageReadFailed: 'Failed to read test image', + auditTestResult: 'Audit Test Result', + auditTestHighest: 'Top category {category}, score {score}', + auditTestComposite: 'Composite score', + auditTestFlagged: 'Threshold hit', + auditTestPassed: 'Pass', + notConfigured: 'Not configured', + clearApiKey: 'Clear stored key', + keepApiKey: 'Keep stored key', + timeoutMs: 'HTTP Timeout (ms)', + retryCount: 'Retry Count', + sampleRate: 'Sample Rate', + recordNonHits: 'Record Non-Hits', + recordNonHitsHint: 'When enabled, sampled non-hit request summaries are redacted before storage.', + preHashCheck: 'Enable Pre-Hash Check', + preHashCheckHint: 'Hashes from async hits are blocked before moderation; this does not send email or increment ban counters.', + flaggedHashCount: 'Current hash collection size: {count}', + flaggedHashHint: 'Hashes are stored permanently in Redis; paste a full 64-character hash to remove a false block, or clear all stored hashes.', + flaggedHashPlaceholder: 'Paste full 64-character input hash', + deleteFlaggedHash: 'Delete hash', + clearFlaggedHashes: 'Clear all', + clearFlaggedHashesConfirm: 'Clear all risk input hashes? This does not delete audit records, but removes all historical hash blocks.', + flaggedHashDeleted: 'Risk hash deleted', + flaggedHashNotFound: 'Risk hash not found', + flaggedHashDeleteFailed: 'Failed to delete risk hash', + flaggedHashesCleared: 'Cleared {count} risk hashes', + flaggedHashesClearFailed: 'Failed to clear risk hashes', + workerCount: 'Worker Count', + queueSize: 'Async Queue Size', + blockStatus: 'Block HTTP Status', + blockMessage: 'Custom Block Message', + emailOnHit: 'Email on Hit', + emailOnHitHint: 'When enabled, send a risk-control email on every hit; auto-ban notices are always sent.', + autoBan: 'Auto Ban User', + autoBanHint: 'Disable the user, invalidate auth cache, and send a ban notice after the hit threshold is reached.', + banThreshold: 'Ban Threshold', + violationWindowHours: 'Count Window (hours)', + hitRetentionDays: 'Hit Record Retention (days)', + nonHitRetentionDays: 'Non-Hit Record Retention (days, max 3)', + violationCount: '{count} hits', + emailSent: 'Email sent', + emailNotSent: 'No email', + autoBanned: 'Banned', + unbanUser: 'Unban', + unbanSuccess: 'User has been unbanned', + unbanFailed: 'Failed to unban user', + inputDetailTitle: 'Input Summary Detail', + inputDetailContent: 'Full Content', + queueDelay: 'Queued {ms} ms', + allGroups: 'All Groups', + allGroupsHint: 'Auditing all groups', + selectedGroupsHint: 'Auditing selected groups', + groupScope: 'Audit Groups', + groupScopeHint: 'Switch on for all groups, or turn off to choose specific groups.', + selectedGroups: 'Selected Groups', + searchGroups: 'Search group name or platform', + noGroups: 'No groups available', + emptyLogs: 'No audit records', + workerStatus: 'Worker Runtime', + workerStatusHint: 'Queue and worker pool status for asynchronous observation tasks.', + workerPool: 'Worker Pool', + workerPoolMeta: '{active} processing, {idle} idle and ready, {total} total', + queueUsage: 'Queue Usage', + activeWorkers: 'Processing', + idleWorkers: 'Idle Ready', + workerActive: 'Processing an asynchronous audit task', + workerIdle: 'Started, idle and ready', + workerDisabled: 'Risk control or content audit is disabled', + processed: 'Processed', + droppedErrors: 'Dropped / Errors', + autoRefresh: 'Auto refresh every 15s', + lastCleanup: 'Last cleanup: {time}', + cleanupStats: 'Last cleanup deleted {hit} hits and {nonHit} non-hits', + riskSwitchOff: 'System switch off', + tabs: { + basic: 'Basic', + scope: 'Scope', + runtime: 'Runtime', + response: 'Hit Notice', + retention: 'Retention', + }, + overview: { + status: 'Status', + enabled: 'Enabled', + disabled: 'Disabled', + apiKey: 'API Key', + groupScope: 'Scope', + logs: 'Audit Records', + currentFilter: 'Current filter', + }, + filters: { + search: 'Search user/key/summary', + from: 'From', + to: 'To', + allGroups: 'All Groups', + allEndpoints: 'All Endpoints', + }, + table: { + time: 'Time', + group: 'Group', + user: 'User', + apiKey: 'API Key', + endpoint: 'Endpoint', + result: 'Result', + highest: 'Highest', + actionMeta: 'Action', + latency: 'Latency', + input: 'Input Summary', + }, + result: { + all: 'All Results', + hit: 'Hit', + blocked: 'Blocked', + pass: 'Pass', + error: 'Error', + }, + action: { + block: 'Blocked', + error: 'Error', + }, + }, + // Channel Monitor channelMonitor: { title: 'Channel Monitor', @@ -2559,6 +2780,11 @@ export default { dataExportSelected: 'Export Selected', dataExportIncludeProxies: 'Include proxies linked to the exported accounts', dataImport: 'Import', + moreActions: 'More Actions', + dataActions: 'Data', + toolActions: 'Tools', + viewColumns: 'Columns', + selectedCount: '{count} selected', dataExportConfirmMessage: 'The exported data contains sensitive account and proxy information. Store it securely.', dataExportConfirm: 'Confirm Export', dataExported: 'Data exported successfully', @@ -2944,6 +3170,18 @@ export default { codexCLIOnly: 'Codex official clients only', codexCLIOnlyDesc: 'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.', + codexImageGenerationBridge: 'Codex image-generation bridge', + codexImageGenerationBridgeDesc: + 'Account policy takes precedence over channel and global settings. Only controls whether Codex requests through the /responses text endpoint receive the image_generation tool; standalone image-generation endpoints are unaffected.', + codexImageGenerationBridgeInherit: 'Follow channel', + codexImageGenerationBridgeInheritDesc: 'Do not write an account override; use the channel or global policy.', + codexImageGenerationBridgeEnabled: 'Force on', + codexImageGenerationBridgeEnabledDesc: 'Allow image tool injection for Codex /responses requests.', + codexImageGenerationBridgeDisabled: 'Force off', + codexImageGenerationBridgeDisabledDesc: 'Block image tool injection for Codex /responses requests.', + codexImageGenerationBridgeBadgeInherit: 'Channel policy', + codexImageGenerationBridgeBadgeEnabled: 'Account on', + codexImageGenerationBridgeBadgeDisabled: 'Account off', compactMode: 'Compact mode', compactModeDesc: 'Controls how this account participates in /responses/compact routing. Auto follows probe results, Force On always allows, Force Off always excludes.', @@ -2955,7 +3193,8 @@ export default { 'Only applies to /responses/compact. Use this when the upstream compact endpoint requires a special compact model.', compactSupported: 'Compact supported', compactUnsupported: 'Compact unsupported', - compactUnknown: 'Compact unknown', + compactAuto: 'Compact Auto', + compactUnknown: 'Compact Auto', compactLastChecked: 'Last compact probe', testMode: 'Test mode', testModeDefault: 'Default request', @@ -2989,7 +3228,7 @@ export default { targetNoWildcard: 'Target model cannot contain wildcard *', searchModels: 'Search models...', noMatchingModels: 'No matching models', - fillRelatedModels: 'Fill related models', + fillRelatedModels: 'Sync latest supported models', clearAllModels: 'Clear all models', customModelName: 'Custom model name', enterCustomModelName: 'Enter custom model name', @@ -3254,6 +3493,16 @@ export default { refreshTokenAuth: 'Manual RT Input', refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line', + codexSessionAuth: 'Codex JSON / AT Batch Input', + codexSessionDesc: 'Paste Codex JSON or an accessToken. Accounts use the step 1 settings.', + codexSessionInputLabel: 'Codex JSON or accessToken', + codexSessionPlaceholder: 'Multiple lines supported, one token or JSON per line', + codexSessionHint: 'sessionToken will not be saved as refresh_token. Without refresh_token, the account expires with the accessToken expiry; import is rejected if the expiry cannot be parsed and step 1 has no expiration.', + codexSessionImportAndCreate: 'Import & Create Account', + codexSessionEmpty: 'Please enter Codex JSON or accessToken', + codexSessionImportFailed: 'Failed to import Codex account', + codexSessionImportSuccess: 'Import completed: created {created}, updated {updated}, skipped {skipped}', + codexSessionImportPartial: 'Partial success: created {created}, updated {updated}, skipped {skipped}, failed {failed}', sessionTokenAuth: 'Manual ST Input', sessionTokenDesc: 'Enter your existing Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', sessionTokenPlaceholder: 'Paste your Session Token...\nSupports multiple, one per line', @@ -4903,6 +5152,7 @@ export default { description: 'Manage registration, email verification, default values, and SMTP settings', tabs: { general: 'General', + agreement: 'Agreement', features: 'Feature Switches', security: 'Security', users: 'Users', @@ -4928,6 +5178,13 @@ export default { enabled: 'Enable Available Channels', enabledHint: 'When off, the sidebar entry is hidden and the endpoint returns an empty list.', }, + riskControl: { + title: 'Risk Control', + description: 'Enable the content moderation menu and gateway audit entry point. Disabled by default.', + configureLink: 'Configure content moderation in Risk Control', + enabled: 'Enable Risk Control', + enabledHint: 'When off, the admin sidebar entry is hidden and gateway moderation is skipped.', + }, affiliate: { title: 'Affiliate (Invite Rebate)', description: 'Existing users invite new ones; the inviter earns a percentage rebate on the invitee’s recharges. Disabled by default.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 8e622a39..d486d872 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -384,6 +384,7 @@ export default { channelPricing: '渠道定价', channelMonitor: '渠道监控', channelStatus: '渠道状态', + riskControl: '风控中心', }, // Auth @@ -412,6 +413,9 @@ export default { passwordRequired: '请输入密码', passwordMinLength: '密码至少需要 6 个字符', loginFailed: '登录失败,请检查您的凭据后重试。', + errors: { + USER_NOT_ACTIVE: '账号已被禁用', + }, registrationFailed: '注册失败,请重试。', emailSuffixNotAllowed: '该邮箱域名不在允许注册范围内。', emailSuffixNotAllowedWithAllowed: '该邮箱域名不被允许。可用域名:{suffixes}', @@ -473,6 +477,9 @@ export default { completing: '正在完成注册...', completeRegistrationFailed: '注册失败,请检查邀请码后重试。' }, + emailOAuth: { + signIn: '使用 {providerName} 登录' + }, oidc: { signIn: '使用 {providerName} 登录', callbackTitle: '正在完成 {providerName} 登录', @@ -531,6 +538,8 @@ export default { oauth: { callbackTitle: 'OAuth 回调', callbackHint: '按需将授权码和状态值复制回后台授权流程。', + invalidCallbackTitle: '无效的登录回调', + invalidCallbackHint: '当前页面缺少有效的授权结果,请返回登录页重新发起快捷登录。', code: '授权码', state: '状态', fullUrl: '完整URL' @@ -2363,6 +2372,8 @@ export default { webSearchEmulation: 'Web Search 模拟', webSearchEmulationHint: '⚠️ 开启后该渠道下所有 Anthropic 分组的账号将自动拦截 web_search 请求,请谨慎操作', webSearchEmulationGlobalDisabled: '请先在系统设置 → 网关 → Web Search 模拟中启用全局开关', + codexImageGenerationBridge: 'Codex 图片生成桥接', + codexImageGenerationBridgeHint: '开启后,OpenAI 分组的 Codex /responses 文本请求可能会被自动注入 image_generation 工具。仅在路由账号支持图片生成时开启。', basicSettings: '基础设置', addPlatform: '添加平台', noPlatforms: '点击"添加平台"开始配置渠道', @@ -2385,6 +2396,216 @@ export default { } }, + riskControl: { + title: '风控中心', + description: '配置内容审计策略并查看审核记录', + loadFailed: '加载风控中心失败', + saveFailed: '保存内容审计配置失败', + logsFailed: '加载审核记录失败', + saved: '内容审计配置已保存', + refresh: '刷新', + config: '内容审计配置', + configHint: '调用 OpenAI Moderations 进行请求内容评分,命中阈值后按模式处理。', + openSettings: '内容审计设置', + settingsTitle: '内容审计设置', + refreshStatus: '刷新状态', + records: '审核记录', + recordsHint: '展示命中、拦截、异常和已采样记录。', + saveConfig: '保存内容审计配置', + statusFailed: '加载运行状态失败', + enabled: '开启内容审计', + enabledHint: '关闭后即使风控中心菜单启用,也不会审核网关请求。', + mode: '全局模式', + modePreBlock: '前置拦截', + modePreBlockDesc: '每次请求先同步审核最新用户输入,命中后立即拒绝请求。', + modeObserve: '仅观察', + modeObserveDesc: '请求直接放行,最新用户输入进入异步审核队列;命中后只记录、通知和按规则累计。', + modeOff: '关闭', + modeOffDesc: '不执行内容审计,也不会写入审核记录。', + baseUrl: 'OpenAI Base URL', + model: '模型名', + apiKey: 'OpenAI API Key', + apiKeys: 'OpenAI API Keys', + apiKeyCount: '{count} 个 Key', + apiKeyPlaceholder: '请输入 API Key', + apiKeysPlaceholder: '新增 API Key,每行一个;保存后会追加到已保存 Key', + apiKeysPlaceholderReplace: '覆盖保存 API Key,每行一个;保存后会替换全部已保存 Key', + apiKeysPlaceholderKeep: '新增 API Key,每行一个;保存后会追加到已保存 Key', + apiKeysHint: '当前已保存 {count} 个 Key;输入区只用于新增,保存时会增量追加并自动去重。', + apiKeysWriteMode: '写入方式', + apiKeysModeAppend: '增量添加', + apiKeysModeReplace: '覆盖保存', + apiKeysModeAppendHint: '默认模式:保存时追加输入区 Key,并保留已保存 Key。', + apiKeysModeReplaceHint: '覆盖模式:保存时用输入区 Key 替换全部已保存 Key。', + apiKeysReplaceWarning: '覆盖模式', + apiKeysReplaceNoInput: '覆盖保存至少需要输入 1 个 API Key', + apiKeyPlaceholderKeep: '留空保持不变', + apiKeyWillClear: '保存后清除已配置 Key', + apiKeyConfigured: '已配置', + apiKeyTemporary: '待保存', + apiKeyPendingDelete: '待删除', + apiKeyPendingDeleteCount: '待删除 {count} 个 Key', + deleteApiKey: '删除这个 Key', + undoDeleteApiKey: '撤销删除', + inputApiKeyCount: '输入区 {count} 个 Key', + storedApiKeyCount: '已保存 {count} 个 Key', + testInputApiKeys: '测试输入区 Key', + testStoredApiKeys: '测试已保存 Key', + testContentWithStoredApiKey: '用已保存 Key 试跑内容', + testingApiKeys: '测试中', + apiKeyTestNoInput: '请先输入需要测试的 OpenAI API Key', + apiKeyTestDone: 'Key 测试完成,共 {count} 个', + apiKeyTestFailed: '测试 OpenAI API Key 失败', + apiKeyHealth: 'Key 可用状态', + apiKeyFreezeRule: '400 不冻结;401/403 冻结 10 分钟;429/529 冻结 1 分钟;其他 HTTP 错误冻结 10 秒。', + apiKeyRows: '{count} 个 Key', + apiKeyRowsCollapsed: '已隐藏 {count} 个 Key', + apiKeyRowsExpanded: '正在显示全部 {count} 个 Key', + expandApiKeyRows: '展开', + collapseApiKeyRows: '收起', + apiKeyHealthEmpty: '暂无 Key 状态', + apiKeyHealthEmptyHint: '保存 Key 或测试输入区 Key 后会显示可用性。', + apiKeyStatusOk: '可用', + apiKeyStatusError: '异常', + apiKeyStatusFrozen: '冻结', + apiKeyStatusUnknown: '未测试', + apiKeyFailureCount: '失败 {count} 次', + apiKeyLatency: '{ms} ms', + apiKeyHTTPStatus: 'HTTP {status}', + apiKeyFrozenUntil: '冻结至 {time}', + apiKeyLastChecked: '检查于 {time}', + apiKeyNotTested: '尚未测试', + auditTestInput: '审计试跑输入', + auditTestInputHint: '可填写提示词并上传或粘贴图片;图片以 base64 发送,不会保存文件。', + auditTestPromptPlaceholder: '输入要测试的用户提示词;留空时仅测试 Key 可用性。', + auditTestImages: '测试图片', + auditTestImagesHint: '支持上传、拖拽或粘贴图片,最多 1 张,每张不超过 8MB。', + addAuditTestImage: '添加图片', + clearAuditTest: '清空试跑', + auditTestImageLimit: '最多只能添加 {count} 张测试图片', + auditTestImageTooLarge: '单张测试图片不能超过 8MB', + auditTestImageReadFailed: '读取测试图片失败', + auditTestResult: '审计试跑结果', + auditTestHighest: '最高分类 {category},分数 {score}', + auditTestComposite: '综合评分', + auditTestFlagged: '命中阈值', + auditTestPassed: '未命中', + notConfigured: '未配置', + clearApiKey: '清除已保存 Key', + keepApiKey: '保留已保存 Key', + timeoutMs: 'HTTP 超时 (ms)', + retryCount: '失败重试次数', + sampleRate: '采样率', + recordNonHits: '记录未命中输入', + recordNonHitsHint: '开启后会记录抽样但未命中的请求摘要,摘要会先脱敏再入库。', + preHashCheck: '启用前置哈希比对', + preHashCheckHint: '异步审核命中过的输入哈希会被前置拦截;该拦截不发送邮件,也不累计封禁次数。', + flaggedHashCount: '当前哈希集合数量:{count} 个', + flaggedHashHint: '哈希永久保存在 Redis 集合中;可粘贴完整 64 位哈希删除误拦截项,或一键清空全部风险哈希。', + flaggedHashPlaceholder: '粘贴完整 64 位输入哈希', + deleteFlaggedHash: '删除指定哈希', + clearFlaggedHashes: '一键清空', + clearFlaggedHashesConfirm: '确定要清空全部风险输入哈希吗?此操作不会删除审核记录,但会取消所有历史哈希拦截。', + flaggedHashDeleted: '风险哈希已删除', + flaggedHashNotFound: '该风险哈希不存在', + flaggedHashDeleteFailed: '删除风险哈希失败', + flaggedHashesCleared: '已清空 {count} 个风险哈希', + flaggedHashesClearFailed: '清空风险哈希失败', + workerCount: 'Worker 数', + queueSize: '异步队列大小', + blockStatus: '拦截 HTTP 状态码', + blockMessage: '自定义拦截提示', + emailOnHit: '命中后发送邮件', + emailOnHitHint: '开启后每次达到阈值都会向用户发送风控提醒邮件;自动封禁通知始终发送。', + autoBan: '自动封禁用户', + autoBanHint: '命中次数达到阈值后将禁用用户账号、刷新认证缓存并发送封禁通知邮件。', + banThreshold: '封禁触发次数', + violationWindowHours: '累计窗口(小时)', + hitRetentionDays: '命中记录保留(天)', + nonHitRetentionDays: '未命中记录保留(天,最多 3 天)', + violationCount: '{count} 次', + emailSent: '已发邮件', + emailNotSent: '未发邮件', + autoBanned: '已封禁', + unbanUser: '解封', + unbanSuccess: '用户已解封', + unbanFailed: '解封用户失败', + inputDetailTitle: '输入摘要详情', + inputDetailContent: '完整内容', + queueDelay: '排队 {ms} ms', + allGroups: '全部分组', + allGroupsHint: '当前审计全部分组', + selectedGroupsHint: '当前审计指定分组', + groupScope: '审计分组', + groupScopeHint: '开启右侧开关表示全部分组,关闭后选择指定分组。', + selectedGroups: '指定分组', + searchGroups: '搜索分组名称或平台', + noGroups: '暂无可用分组', + emptyLogs: '暂无审核记录', + workerStatus: 'Worker 运行状态', + workerStatusHint: '异步观察任务的队列和 worker 池状态。', + workerPool: 'Worker 池', + workerPoolMeta: '{active} 个处理中,{idle} 个空闲可用,共 {total} 个', + queueUsage: '队列占用', + activeWorkers: '处理中', + idleWorkers: '空闲可用', + workerActive: '正在处理异步审计任务', + workerIdle: '已启动,当前空闲可用', + workerDisabled: '风控或内容审计未启用', + processed: '已处理', + droppedErrors: '丢弃/异常', + autoRefresh: '每 15 秒自动刷新', + lastCleanup: '上次清理:{time}', + cleanupStats: '上次清理删除命中 {hit} 条,未命中 {nonHit} 条', + riskSwitchOff: '系统开关关闭', + tabs: { + basic: '基础', + scope: '审计范围', + runtime: '运行队列', + response: '命中通知', + retention: '日志保留', + }, + overview: { + status: '运行状态', + enabled: '已启用', + disabled: '未启用', + apiKey: 'API Key', + groupScope: '审计范围', + logs: '审核记录', + currentFilter: '当前筛选结果', + }, + filters: { + search: '按用户/Key/摘要搜索', + from: '开始时间', + to: '结束时间', + allGroups: '全部分组', + allEndpoints: '全部端点', + }, + table: { + time: '时间', + group: '分组', + user: '用户', + apiKey: 'API Key', + endpoint: '端点', + result: '结果', + highest: '最高分', + actionMeta: '处置', + latency: '上游耗时', + input: '输入摘要', + }, + result: { + all: '全部结果', + hit: '命中', + blocked: '已拦截', + pass: '未命中', + error: '异常', + }, + action: { + block: '拦截', + error: '异常', + }, + }, + // Channel Monitor channelMonitor: { title: '渠道监控', @@ -2635,6 +2856,11 @@ export default { dataExportSelected: '导出选中', dataExportIncludeProxies: '导出代理(导出账号关联的代理)', dataImport: '导入', + moreActions: '更多操作', + dataActions: '数据操作', + toolActions: '工具', + viewColumns: '列显示', + selectedCount: '已选 {count}', dataExportConfirmMessage: '导出的数据包含账号与代理的敏感信息,请妥善保存。', dataExportConfirm: '确认导出', dataExported: '数据导出成功', @@ -3089,6 +3315,18 @@ export default { responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。', codexCLIOnly: '仅允许 Codex 官方客户端', codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。', + codexImageGenerationBridge: 'Codex 图片生成桥接', + codexImageGenerationBridgeDesc: + '账号级策略优先于渠道和全局配置。仅控制 Codex 走 /responses 文本端点时是否注入 image_generation 工具;不影响独立图片生成接口。', + codexImageGenerationBridgeInherit: '跟随渠道', + codexImageGenerationBridgeInheritDesc: '不写入账号覆盖,继续使用渠道或全局策略。', + codexImageGenerationBridgeEnabled: '强制开启', + codexImageGenerationBridgeEnabledDesc: '允许 Codex /responses 请求获得图片工具注入。', + codexImageGenerationBridgeDisabled: '强制关闭', + codexImageGenerationBridgeDisabledDesc: '阻断 Codex /responses 的图片工具注入。', + codexImageGenerationBridgeBadgeInherit: '渠道策略', + codexImageGenerationBridgeBadgeEnabled: '账号开启', + codexImageGenerationBridgeBadgeDisabled: '账号关闭', compactMode: 'Compact 模式', compactModeDesc: '控制本账号在 /responses/compact 调度中的参与方式。Auto 跟随探测结果,Force On 强制允许,Force Off 强制排除。', @@ -3100,7 +3338,8 @@ export default { '仅在 /responses/compact 请求中生效。当上游 compact 端点需要特殊 compact 模型时使用。', compactSupported: '支持 Compact', compactUnsupported: '不支持 Compact', - compactUnknown: 'Compact 未知', + compactAuto: 'Compact Auto', + compactUnknown: 'Compact Auto', compactLastChecked: '最近探测', testMode: '测试模式', testModeDefault: '常规请求', @@ -3133,7 +3372,7 @@ export default { targetNoWildcard: '目标模型不能包含通配符 *', searchModels: '搜索模型...', noMatchingModels: '没有匹配的模型', - fillRelatedModels: '填入相关模型', + fillRelatedModels: '同步最新支持模型', clearAllModels: '清除所有模型', customModelName: '自定义模型名称', enterCustomModelName: '输入自定义模型名称', @@ -3389,6 +3628,16 @@ export default { refreshTokenAuth: '手动输入 RT', refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。', refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个', + codexSessionAuth: 'Codex JSON / AT 批量输入', + codexSessionDesc: '粘贴 Codex JSON 或 accessToken,按第一步配置创建账号。', + codexSessionInputLabel: 'Codex JSON 或 accessToken', + codexSessionPlaceholder: '支持多行,每行一个 token 或 JSON', + codexSessionHint: 'sessionToken 不会作为 refresh_token 保存;未包含 refresh_token 时会按 accessToken 过期时间设置账号过期,无法解析且第一步未设置过期时间时会拒绝导入。', + codexSessionImportAndCreate: '导入并创建账号', + codexSessionEmpty: '请输入 Codex JSON 或 accessToken', + codexSessionImportFailed: 'Codex 账号导入失败', + codexSessionImportSuccess: '导入完成:新增 {created},更新 {updated},跳过 {skipped}', + codexSessionImportPartial: '部分成功:新增 {created},更新 {updated},跳过 {skipped},失败 {failed}', sessionTokenAuth: '手动输入 ST', sessionTokenDesc: '输入您已有的 Session Token,支持批量输入(每行一个),系统将自动验证并创建账号。', sessionTokenPlaceholder: '粘贴您的 Session Token...\n支持多个,每行一个', @@ -5066,6 +5315,7 @@ export default { description: '管理注册、邮箱验证、默认值和 SMTP 设置', tabs: { general: '通用设置', + agreement: '登录条款', features: '功能开关', security: '安全与认证', users: '用户默认值', @@ -5091,6 +5341,13 @@ export default { enabled: '启用可用渠道', enabledHint: '关闭后用户端侧边栏入口隐藏,接口返回空数组。', }, + riskControl: { + title: '风控中心', + description: '启用内容审计菜单和全端点请求审核入口。默认关闭。', + configureLink: '前往 风控中心 配置内容审计', + enabled: '启用风控中心', + enabledHint: '关闭后管理员侧边栏入口隐藏,网关内容审计不会执行。', + }, affiliate: { title: '邀请返利', description: '老用户邀请新用户注册,新用户充值后老用户按比例获得返利额度。默认关闭。', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 3947d13f..68b934db 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -68,6 +68,7 @@ const routes: RouteRecordRaw[] = [ { path: '/auth/callback', name: 'OAuthCallback', + alias: '/auth/oauth/callback', component: () => import('@/views/auth/OAuthCallbackView.vue'), meta: { requiresAuth: false, @@ -143,6 +144,15 @@ const routes: RouteRecordRaw[] = [ title: 'Key Usage', } }, + { + path: '/legal/:documentId', + name: 'LegalDocument', + component: () => import('@/views/public/LegalDocumentView.vue'), + meta: { + requiresAuth: false, + title: 'Legal Document' + } + }, // ==================== User Routes ==================== { @@ -529,6 +539,19 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.settings.description' } }, + { + path: '/admin/risk-control', + name: 'AdminRiskControl', + component: () => import('@/views/admin/RiskControlView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Risk Control', + titleKey: 'admin.riskControl.title', + descriptionKey: 'admin.riskControl.description', + requiresRiskControl: true + } + }, { path: '/admin/usage', name: 'AdminUsage', @@ -657,7 +680,7 @@ let authInitialized = false const navigationLoading = useNavigationLoadingState() // 延迟初始化预加载,传入 router 实例 let routePrefetch: ReturnType | null = null -const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup', '/payment/result'] +const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup', '/payment/result', '/legal'] const BACKEND_MODE_CALLBACK_PATHS = [ '/auth/callback', '/auth/linuxdo/callback', @@ -771,6 +794,14 @@ router.beforeEach((to, _from, next) => { } } + if (to.meta.requiresRiskControl) { + const riskControlEnabled = appStore.cachedPublicSettings?.risk_control_enabled === true + if (!riskControlEnabled) { + next(authStore.isAdmin ? '/admin/settings' : '/dashboard') + return + } + } + // 简易模式下限制访问某些页面 if (authStore.isSimpleMode) { const restrictedPaths = [ diff --git a/frontend/src/router/meta.d.ts b/frontend/src/router/meta.d.ts index 7b2777c2..5c468016 100644 --- a/frontend/src/router/meta.d.ts +++ b/frontend/src/router/meta.d.ts @@ -49,6 +49,12 @@ declare module 'vue-router' { */ requiresPayment?: boolean + /** + * 是否要求风控中心功能开关已启用 + * @default false + */ + requiresRiskControl?: boolean + /** * i18n key for the page title */ diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 876ab5c0..4d701b2e 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -347,6 +347,8 @@ export const useAppStore = defineStore('app', () => { wechat_oauth_mobile_enabled: false, oidc_oauth_enabled: false, oidc_oauth_provider_name: 'OIDC', + github_oauth_enabled: false, + google_oauth_enabled: false, backend_mode_enabled: false, version: siteVersion.value, balance_low_notify_enabled: false, @@ -355,6 +357,7 @@ export const useAppStore = defineStore('app', () => { channel_monitor_enabled: true, channel_monitor_default_interval_seconds: 60, available_channels_enabled: false, + risk_control_enabled: false, affiliate_enabled: false, } } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eb3364c0..17bf4f71 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -34,7 +34,7 @@ export interface NotifyEmailEntry { // ==================== User & Auth Types ==================== -export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat' +export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat' | 'github' | 'google' export interface UserAuthBindingStatus { bound?: boolean @@ -168,6 +168,7 @@ export interface CustomMenuItem { label: string icon_svg: string url: string + page_slug?: string visibility: 'user' | 'admin' sort_order: number } @@ -178,6 +179,12 @@ export interface CustomEndpoint { description: string } +export interface LoginAgreementDocument { + id: string + title: string + content_md: string +} + export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean @@ -186,6 +193,11 @@ export interface PublicSettings { promo_code_enabled: boolean password_reset_enabled: boolean invitation_code_enabled: boolean + login_agreement_enabled?: boolean + login_agreement_mode?: 'modal' | 'checkbox' | string + login_agreement_updated_at?: string + login_agreement_revision?: string + login_agreement_documents?: LoginAgreementDocument[] turnstile_enabled: boolean turnstile_site_key: string site_name: string @@ -197,6 +209,7 @@ export interface PublicSettings { home_content: string hide_ccs_import_button: boolean payment_enabled: boolean + risk_control_enabled: boolean table_default_page_size: number table_page_size_options: number[] custom_menu_items: CustomMenuItem[] @@ -208,6 +221,8 @@ export interface PublicSettings { wechat_oauth_mobile_enabled?: boolean oidc_oauth_enabled: boolean oidc_oauth_provider_name: string + github_oauth_enabled: boolean + google_oauth_enabled: boolean backend_mode_enabled: boolean version: string balance_low_notify_enabled: boolean @@ -1278,6 +1293,51 @@ export interface AdminDataImportResult { errors?: AdminDataImportError[] } +export interface CodexSessionImportRequest { + content?: string + contents?: string[] + name?: string + notes?: string | null + group_ids?: number[] + proxy_id?: number | null + concurrency?: number + priority?: number + rate_multiplier?: number + load_factor?: number | null + expires_at?: number | null + auto_pause_on_expired?: boolean + credential_extras?: Record + extra?: Record + update_existing?: boolean + skip_default_group_bind?: boolean + confirm_mixed_channel_risk?: boolean +} + +export interface CodexSessionImportMessage { + index: number + name?: string + message: string +} + +export interface CodexSessionImportItem { + index: number + name?: string + action: 'created' | 'updated' | 'skipped' | 'failed' + account_id?: number + message?: string +} + +export interface CodexSessionImportResult { + total: number + created: number + updated: number + skipped: number + failed: number + items?: CodexSessionImportItem[] + warnings?: CodexSessionImportMessage[] + errors?: CodexSessionImportMessage[] +} + // ==================== Usage & Redeem Types ==================== export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' | 'invitation' diff --git a/frontend/src/utils/featureFlags.ts b/frontend/src/utils/featureFlags.ts index e0668694..403e7cdc 100644 --- a/frontend/src/utils/featureFlags.ts +++ b/frontend/src/utils/featureFlags.ts @@ -109,6 +109,11 @@ export const FeatureFlags = { mode: 'opt-out', label: 'Payment', }), + riskControl: defineFlag({ + key: 'risk_control_enabled', + mode: 'opt-in', + label: 'Risk Control', + }), affiliate: defineFlag({ key: 'affiliate_enabled', mode: 'opt-in', diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 79d1800e..deb399c9 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -14,7 +14,6 @@