revert: completely remove all Sora functionality
This commit is contained in:
parent
dbb248df52
commit
62e80c602d
@ -102,12 +102,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||
soraAccountRepository := repository.NewSoraAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@ -143,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
@ -184,12 +183,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
soraS3Storage := service.NewSoraS3Storage(settingService)
|
||||
settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
|
||||
soraGenerationRepository := repository.NewSoraGenerationRepository(db)
|
||||
soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
|
||||
soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
@ -223,16 +217,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
|
||||
soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
|
||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
||||
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
@ -243,12 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@ -283,7 +271,6 @@ func provideCleanup(
|
||||
opsCleanup *service.OpsCleanupService,
|
||||
opsScheduledReport *service.OpsScheduledReportService,
|
||||
opsSystemLogSink *service.OpsSystemLogSink,
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
@ -331,12 +318,6 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"SoraMediaCleanupService", func() error {
|
||||
if soraMediaCleanup != nil {
|
||||
soraMediaCleanup.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"OpsAlertEvaluatorService", func() error {
|
||||
if opsAlertEvaluator != nil {
|
||||
opsAlertEvaluator.Stop()
|
||||
|
||||
@ -57,7 +57,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
&service.OpsCleanupService{},
|
||||
&service.OpsScheduledReportService{},
|
||||
opsSystemLogSinkSvc,
|
||||
&service.SoraMediaCleanupService{},
|
||||
schedulerSnapshotSvc,
|
||||
tokenRefreshSvc,
|
||||
accountExpirySvc,
|
||||
|
||||
@ -77,7 +77,6 @@ type Config struct {
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Sora SoraConfig `mapstructure:"sora"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
@ -197,8 +196,6 @@ type TokenRefreshConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
||||
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@ -303,59 +300,6 @@ type ConcurrencyConfig struct {
|
||||
PingInterval int `mapstructure:"ping_interval"`
|
||||
}
|
||||
|
||||
// SoraConfig 直连 Sora 配置
|
||||
type SoraConfig struct {
|
||||
Client SoraClientConfig `mapstructure:"client"`
|
||||
Storage SoraStorageConfig `mapstructure:"storage"`
|
||||
}
|
||||
|
||||
// SoraClientConfig 直连 Sora 客户端配置
|
||||
type SoraClientConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
|
||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
||||
Headers map[string]string `mapstructure:"headers"`
|
||||
UserAgent string `mapstructure:"user_agent"`
|
||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
|
||||
}
|
||||
|
||||
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
|
||||
type SoraCurlCFFISidecarConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Impersonate string `mapstructure:"impersonate"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
|
||||
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
|
||||
}
|
||||
|
||||
// SoraStorageConfig 媒体存储配置
|
||||
type SoraStorageConfig struct {
|
||||
Type string `mapstructure:"type"`
|
||||
LocalPath string `mapstructure:"local_path"`
|
||||
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
||||
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
||||
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
|
||||
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
||||
}
|
||||
|
||||
// SoraStorageCleanupConfig 媒体清理配置
|
||||
type SoraStorageCleanupConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Schedule string `mapstructure:"schedule"`
|
||||
RetentionDays int `mapstructure:"retention_days"`
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
@ -424,24 +368,6 @@ type GatewayConfig struct {
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
|
||||
// Sora 专用配置
|
||||
// SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
|
||||
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
|
||||
// SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
|
||||
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
|
||||
// SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
|
||||
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
|
||||
// SoraStreamMode: stream 强制策略(force/error)
|
||||
SoraStreamMode string `mapstructure:"sora_stream_mode"`
|
||||
// SoraModelFilters: 模型列表过滤配置
|
||||
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
|
||||
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
|
||||
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
|
||||
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
|
||||
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
|
||||
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
|
||||
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
|
||||
|
||||
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||
@ -639,12 +565,6 @@ type GatewayUsageRecordConfig struct {
|
||||
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
|
||||
}
|
||||
|
||||
// SoraModelFiltersConfig Sora 模型过滤配置
|
||||
type SoraModelFiltersConfig struct {
|
||||
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
|
||||
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||
type TLSFingerprintConfig struct {
|
||||
@ -1402,13 +1322,6 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||
viper.SetDefault("gateway.sora_stream_mode", "force")
|
||||
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
|
||||
viper.SetDefault("gateway.sora_media_require_api_key", true)
|
||||
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
|
||||
@ -1465,45 +1378,12 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// Sora 直连配置
|
||||
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
||||
viper.SetDefault("sora.client.timeout_seconds", 120)
|
||||
viper.SetDefault("sora.client.max_retries", 3)
|
||||
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
|
||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||
viper.SetDefault("sora.client.debug", false)
|
||||
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
|
||||
|
||||
viper.SetDefault("sora.storage.type", "local")
|
||||
viper.SetDefault("sora.storage.local_path", "")
|
||||
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
||||
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
||||
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
|
||||
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
|
||||
viper.SetDefault("sora.storage.debug", false)
|
||||
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
||||
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
||||
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
||||
|
||||
// Gemini OAuth - configure via environment variables or config file
|
||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||
@ -1879,86 +1759,6 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.SoraMaxBodySize < 0 {
|
||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
|
||||
}
|
||||
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
|
||||
switch mode {
|
||||
case "force", "error":
|
||||
default:
|
||||
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
|
||||
}
|
||||
}
|
||||
if c.Sora.Client.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.MaxRetries < 0 {
|
||||
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.PollIntervalSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.MaxPollAttempts < 0 {
|
||||
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimit < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
|
||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
|
||||
}
|
||||
if !c.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
|
||||
}
|
||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.MaxDownloadBytes < 0 {
|
||||
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.Cleanup.Enabled {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
|
||||
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
|
||||
}
|
||||
} else {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
|
||||
}
|
||||
}
|
||||
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
|
||||
return fmt.Errorf("sora.storage.type must be 'local'")
|
||||
}
|
||||
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
||||
switch c.Gateway.ConnectionPoolIsolation {
|
||||
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
||||
|
||||
@ -1554,93 +1554,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
|
||||
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
|
||||
}
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
|
||||
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
@ -22,7 +22,6 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
|
||||
@ -567,15 +567,15 @@ func defaultProxyName(name string) string {
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Only applies to OpenAI OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
// Only enrich OpenAI OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
if platform != service.PlatformOpenAI {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
|
||||
@ -1875,12 +1875,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Sora accounts
|
||||
if account.Platform == service.PlatformSora {
|
||||
response.Success(c, service.DefaultSoraModels(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
|
||||
@ -380,7 +380,6 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
|
||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@ -95,10 +95,6 @@ type CreateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@ -108,8 +104,6 @@ type CreateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||
@ -123,7 +117,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
@ -135,10 +129,6 @@ type UpdateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@ -148,8 +138,6 @@ type UpdateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||
RequireOAuthOnly *bool `json:"require_oauth_only"`
|
||||
@ -258,10 +246,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@ -269,7 +253,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: req.RequireOAuthOnly,
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
@ -313,10 +296,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@ -324,7 +303,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: req.RequireOAuthOnly,
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
|
||||
@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct {
|
||||
}
|
||||
|
||||
func oauthPlatformFromPath(c *gin.Context) string {
|
||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||
return service.PlatformSora
|
||||
}
|
||||
return service.PlatformOpenAI
|
||||
}
|
||||
|
||||
@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct {
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
// POST /api/v1/admin/sora/rt2at
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||
// POST /api/v1/admin/sora/st2at
|
||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
ST string `json:"st"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||
if sessionToken == "" {
|
||||
sessionToken = strings.TrimSpace(req.ST)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
response.BadRequest(c, "session_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
// POST /api/v1/admin/sora/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
if platform == service.PlatformSora {
|
||||
name = "Sora OAuth Account"
|
||||
} else {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
|
||||
// Create account
|
||||
|
||||
@ -108,7 +108,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
@ -177,7 +176,6 @@ type UpdateSettingsRequest struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
@ -566,7 +564,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
@ -676,7 +673,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
|
||||
@ -41,7 +41,6 @@ type CreateUserRequest struct {
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
@ -58,7 +57,6 @@ type UpdateUserRequest struct {
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@ -189,7 +187,6 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -225,7 +222,6 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@ -59,11 +59,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
}
|
||||
}
|
||||
|
||||
@ -172,14 +170,9 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: g.RequireOAuthOnly,
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
|
||||
@ -61,7 +61,6 @@ type SystemSettings struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
@ -128,49 +127,10 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
||||
type ListSoraS3ProfilesResponse struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
|
||||
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
type OverloadCooldownSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@ -26,9 +26,7 @@ type AdminUser struct {
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
@ -84,21 +82,12 @@ type Group struct {
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ const (
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
||||
// prefixes like /antigravity, /openai) to its canonical form.
|
||||
//
|
||||
// "/antigravity/v1/messages" → "/v1/messages"
|
||||
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||
@ -61,7 +61,7 @@ func NormalizeInboundEndpoint(path string) string {
|
||||
// such as /v1/responses/compact preserved from the raw URL).
|
||||
// - Anthropic → /v1/messages
|
||||
// - Gemini → /v1beta/models
|
||||
// - Sora → /v1/chat/completions
|
||||
// - Antigravity → /v1/messages (Claude) or gemini (Gemini)
|
||||
// - Antigravity routes may target either Claude or Gemini, so the
|
||||
// inbound endpoint is used to distinguish.
|
||||
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
@ -82,9 +82,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
case service.PlatformGemini:
|
||||
return EndpointGeminiModels
|
||||
|
||||
case service.PlatformSora:
|
||||
return EndpointChatCompletions
|
||||
|
||||
case service.PlatformAntigravity:
|
||||
// Antigravity accounts serve both Claude and Gemini.
|
||||
if inbound == EndpointGeminiModels {
|
||||
|
||||
@ -27,11 +27,10 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1beta/models", EndpointGeminiModels},
|
||||
|
||||
// Prefixed paths (antigravity, openai, sora).
|
||||
// Prefixed paths (antigravity, openai).
|
||||
{"/antigravity/v1/messages", EndpointMessages},
|
||||
{"/openai/v1/responses", EndpointResponses},
|
||||
{"/openai/v1/responses/compact", EndpointResponses},
|
||||
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||
|
||||
// Gin route patterns with wildcards.
|
||||
@ -68,9 +67,6 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
// Gemini.
|
||||
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||
|
||||
// Sora.
|
||||
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
||||
|
||||
// OpenAI — always /v1/responses.
|
||||
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||
|
||||
@ -859,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
platform = forcedPlatform
|
||||
}
|
||||
|
||||
if platform == service.PlatformSora {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": service.DefaultSoraModels(h.cfg),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get available models from account configurations (without platform filter)
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||
|
||||
|
||||
@ -45,8 +45,6 @@ type Handlers struct {
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
SoraClient *SoraClientHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
@ -54,7 +54,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
|
||||
@ -1,979 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// 上游模型缓存 TTL
|
||||
modelCacheTTL = 1 * time.Hour // 上游获取成功
|
||||
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
|
||||
)
|
||||
|
||||
// SoraClientHandler 处理 Sora 客户端 API 请求。
|
||||
type SoraClientHandler struct {
|
||||
genService *service.SoraGenerationService
|
||||
quotaService *service.SoraQuotaService
|
||||
s3Storage *service.SoraS3Storage
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
gatewayService *service.GatewayService
|
||||
mediaStorage *service.SoraMediaStorage
|
||||
apiKeyService *service.APIKeyService
|
||||
|
||||
// 上游模型缓存
|
||||
modelCacheMu sync.RWMutex
|
||||
cachedFamilies []service.SoraModelFamily
|
||||
modelCacheTime time.Time
|
||||
modelCacheUpstream bool // 是否来自上游(决定 TTL)
|
||||
}
|
||||
|
||||
// NewSoraClientHandler 创建 Sora 客户端 Handler。
|
||||
func NewSoraClientHandler(
|
||||
genService *service.SoraGenerationService,
|
||||
quotaService *service.SoraQuotaService,
|
||||
s3Storage *service.SoraS3Storage,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
gatewayService *service.GatewayService,
|
||||
mediaStorage *service.SoraMediaStorage,
|
||||
apiKeyService *service.APIKeyService,
|
||||
) *SoraClientHandler {
|
||||
return &SoraClientHandler{
|
||||
genService: genService,
|
||||
quotaService: quotaService,
|
||||
s3Storage: s3Storage,
|
||||
soraGatewayService: soraGatewayService,
|
||||
gatewayService: gatewayService,
|
||||
mediaStorage: mediaStorage,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRequest 生成请求。
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
MediaType string `json:"media_type"` // video / image,默认 video
|
||||
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
|
||||
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
|
||||
}
|
||||
|
||||
// Generate 异步生成 — 创建 pending 记录后立即返回。
|
||||
// POST /api/v1/sora/generate
|
||||
func (h *SoraClientHandler) Generate(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req GenerateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.MediaType == "" {
|
||||
req.MediaType = "video"
|
||||
}
|
||||
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
|
||||
|
||||
// 并发数检查(最多 3 个)
|
||||
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if activeCount >= 3 {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
|
||||
// 配额检查(粗略检查,实际文件大小在上传后才知道)
|
||||
if h.quotaService != nil {
|
||||
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 API Key ID 和 Group ID
|
||||
var apiKeyID *int64
|
||||
var groupID *int64
|
||||
|
||||
if req.APIKeyID != nil && h.apiKeyService != nil {
|
||||
// 前端传递了 api_key_id,需要校验
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "API Key 不存在")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != userID {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
|
||||
return
|
||||
}
|
||||
if apiKey.Status != service.StatusAPIKeyActive {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不可用")
|
||||
return
|
||||
}
|
||||
apiKeyID = &apiKey.ID
|
||||
groupID = apiKey.GroupID
|
||||
} else if id, ok := c.Get("api_key_id"); ok {
|
||||
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
|
||||
if v, ok := id.(int64); ok {
|
||||
apiKeyID = &v
|
||||
}
|
||||
}
|
||||
|
||||
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 启动后台异步生成 goroutine
|
||||
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"generation_id": gen.ID,
|
||||
"status": gen.Status,
|
||||
})
|
||||
}
|
||||
|
||||
// processGeneration 后台异步执行 Sora 生成任务。
|
||||
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
|
||||
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 标记为生成中
|
||||
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
mediaType,
|
||||
videoCount,
|
||||
strings.TrimSpace(imageInput) != "",
|
||||
len(strings.TrimSpace(prompt)),
|
||||
)
|
||||
|
||||
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
|
||||
if groupID == nil {
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
}
|
||||
|
||||
if h.gatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 选择 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
err,
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
account.ID,
|
||||
account.Name,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
)
|
||||
|
||||
// 构建 chat completions 请求体(非流式)
|
||||
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
|
||||
|
||||
if h.soraGatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
|
||||
recorder := httptest.NewRecorder()
|
||||
mockGinCtx, _ := gin.CreateTestContext(recorder)
|
||||
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
|
||||
// 调用 Forward(非流式)
|
||||
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
err,
|
||||
)
|
||||
// 检查是否已取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
|
||||
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
|
||||
if mediaURL == "" {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否已被取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
|
||||
return
|
||||
}
|
||||
|
||||
// 三层降级存储:S3 → 本地 → 上游临时 URL
|
||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
|
||||
|
||||
usageAdded := false
|
||||
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
|
||||
gen, _ = h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 标记完成
|
||||
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
|
||||
}
|
||||
|
||||
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
|
||||
func (h *SoraClientHandler) storeMediaWithDegradation(
|
||||
ctx context.Context, userID int64, mediaType string,
|
||||
mediaURL string, mediaURLs []string,
|
||||
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
|
||||
urls := mediaURLs
|
||||
if len(urls) == 0 {
|
||||
urls = []string{mediaURL}
|
||||
}
|
||||
|
||||
// 第一层:尝试 S3
|
||||
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
|
||||
keys := make([]string, 0, len(urls))
|
||||
var totalSize int64
|
||||
allOK := true
|
||||
for _, u := range urls {
|
||||
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
|
||||
allOK = false
|
||||
// 清理已上传的文件
|
||||
if len(keys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
}
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
totalSize += size
|
||||
}
|
||||
if allOK && len(keys) > 0 {
|
||||
accessURLs := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
allOK = false
|
||||
break
|
||||
}
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
}
|
||||
if allOK && len(accessURLs) > 0 {
|
||||
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二层:尝试本地存储
|
||||
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
|
||||
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
|
||||
if err == nil && len(storedPaths) > 0 {
|
||||
firstPath := storedPaths[0]
|
||||
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
|
||||
if sizeErr != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
|
||||
}
|
||||
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
|
||||
}
|
||||
|
||||
// 第三层:保留上游临时 URL
|
||||
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
|
||||
}
|
||||
|
||||
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
|
||||
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
|
||||
body := map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
if imageInput != "" {
|
||||
body["image_input"] = imageInput
|
||||
}
|
||||
if videoCount > 1 {
|
||||
body["video_count"] = videoCount
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
return b
|
||||
}
|
||||
|
||||
func normalizeVideoCount(mediaType string, videoCount int) int {
|
||||
if mediaType != "video" {
|
||||
return 1
|
||||
}
|
||||
if videoCount <= 0 {
|
||||
return 1
|
||||
}
|
||||
if videoCount > 3 {
|
||||
return 3
|
||||
}
|
||||
return videoCount
|
||||
}
|
||||
|
||||
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
|
||||
// OAuth 路径:ForwardResult.MediaURL 已填充。
|
||||
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
|
||||
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
|
||||
// 优先从 ForwardResult 获取(OAuth 路径)
|
||||
if result != nil && result.MediaURL != "" {
|
||||
// 尝试从响应体获取完整 URL 列表
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
return result.MediaURL, []string{result.MediaURL}
|
||||
}
|
||||
|
||||
// 从响应体解析(APIKey 路径)
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
|
||||
func parseMediaURLsFromBody(body []byte) []string {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先 media_urls(多图数组)
|
||||
if rawURLs, ok := resp["media_urls"]; ok {
|
||||
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
|
||||
urls := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
urls = append(urls, s)
|
||||
}
|
||||
}
|
||||
if len(urls) > 0 {
|
||||
return urls
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 media_url(单个 URL)
|
||||
if url, ok := resp["media_url"].(string); ok && url != "" {
|
||||
return []string{url}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGenerations 查询生成记录列表。
|
||||
// GET /api/v1/sora/generations
|
||||
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
params := service.SoraGenerationListParams{
|
||||
UserID: userID,
|
||||
Status: c.Query("status"),
|
||||
StorageType: c.Query("storage_type"),
|
||||
MediaType: c.Query("media_type"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
gens, total, err := h.genService.List(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 为 S3 记录动态生成预签名 URL
|
||||
for _, gen := range gens {
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"data": gens,
|
||||
"total": total,
|
||||
"page": page,
|
||||
})
|
||||
}
|
||||
|
||||
// GetGeneration 查询生成记录详情。
|
||||
// GET /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
response.Success(c, gen)
|
||||
}
|
||||
|
||||
// DeleteGeneration 删除生成记录。
|
||||
// DELETE /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
|
||||
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
|
||||
paths := gen.MediaURLs
|
||||
if len(paths) == 0 && gen.MediaURL != "" {
|
||||
paths = []string{gen.MediaURL}
|
||||
}
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
// GetQuota 查询用户存储配额。
|
||||
// GET /api/v1/sora/quota
|
||||
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
if h.quotaService == nil {
|
||||
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
|
||||
return
|
||||
}
|
||||
|
||||
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, quota)
|
||||
}
|
||||
|
||||
// CancelGeneration 取消生成任务。
|
||||
// POST /api/v1/sora/generations/:id/cancel
|
||||
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
// 权限校验
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
_ = gen
|
||||
|
||||
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationNotActive) {
|
||||
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已取消"})
|
||||
}
|
||||
|
||||
// SaveToStorage 手动保存 upstream 记录到 S3。
|
||||
// POST /api/v1/sora/generations/:id/save
|
||||
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if gen.StorageType != service.SoraStorageTypeUpstream {
|
||||
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
|
||||
return
|
||||
}
|
||||
if gen.MediaURL == "" {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
|
||||
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
|
||||
return
|
||||
}
|
||||
|
||||
sourceURLs := gen.MediaURLs
|
||||
if len(sourceURLs) == 0 && gen.MediaURL != "" {
|
||||
sourceURLs = []string{gen.MediaURL}
|
||||
}
|
||||
if len(sourceURLs) == 0 {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
uploadedKeys := make([]string, 0, len(sourceURLs))
|
||||
accessURLs := make([]string, 0, len(sourceURLs))
|
||||
var totalSize int64
|
||||
|
||||
for _, sourceURL := range sourceURLs {
|
||||
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
|
||||
if uploadErr != nil {
|
||||
if len(uploadedKeys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
}
|
||||
var upstreamErr *service.UpstreamDownloadError
|
||||
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
|
||||
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
|
||||
return
|
||||
}
|
||||
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
|
||||
if err != nil {
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
totalSize += fileSize
|
||||
}
|
||||
|
||||
usageAdded := false
|
||||
if totalSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
if err := h.genService.UpdateStorageForCompleted(
|
||||
c.Request.Context(),
|
||||
id,
|
||||
accessURLs[0],
|
||||
accessURLs,
|
||||
service.SoraStorageTypeS3,
|
||||
uploadedKeys,
|
||||
totalSize,
|
||||
); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "已保存到 S3",
|
||||
"object_key": uploadedKeys[0],
|
||||
"object_keys": uploadedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStorageStatus 返回存储状态。
|
||||
// GET /api/v1/sora/storage-status
|
||||
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
|
||||
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
|
||||
s3Healthy := false
|
||||
if s3Enabled {
|
||||
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
|
||||
}
|
||||
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
|
||||
response.Success(c, gin.H{
|
||||
"s3_enabled": s3Enabled,
|
||||
"s3_healthy": s3Healthy,
|
||||
"local_enabled": localEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
|
||||
switch storageType {
|
||||
case service.SoraStorageTypeS3:
|
||||
if h.s3Storage != nil && len(s3Keys) > 0 {
|
||||
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
|
||||
}
|
||||
}
|
||||
case service.SoraStorageTypeLocal:
|
||||
if h.mediaStorage != nil && len(localPaths) > 0 {
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
|
||||
func getUserIDFromContext(c *gin.Context) int64 {
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
|
||||
return subject.UserID
|
||||
}
|
||||
|
||||
if id, ok := c.Get("user_id"); ok {
|
||||
switch v := id.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
n, _ := strconv.ParseInt(v, 10, 64)
|
||||
return n
|
||||
}
|
||||
}
|
||||
// 尝试从 JWT claims 获取
|
||||
if id, ok := c.Get("userID"); ok {
|
||||
if v, ok := id.(int64); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func groupIDForLog(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func trimForLog(raw string, maxLen int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if maxLen <= 0 || len(trimmed) <= maxLen {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
// GetModels 获取可用 Sora 模型家族列表。
|
||||
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
|
||||
// GET /api/v1/sora/models
|
||||
func (h *SoraClientHandler) GetModels(c *gin.Context) {
|
||||
families := h.getModelFamilies(c.Request.Context())
|
||||
response.Success(c, families)
|
||||
}
|
||||
|
||||
// getModelFamilies 获取模型家族列表(带缓存)。
|
||||
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
|
||||
// 读锁检查缓存
|
||||
h.modelCacheMu.RLock()
|
||||
ttl := modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
families := h.cachedFamilies
|
||||
h.modelCacheMu.RUnlock()
|
||||
return families
|
||||
}
|
||||
h.modelCacheMu.RUnlock()
|
||||
|
||||
// 写锁更新缓存
|
||||
h.modelCacheMu.Lock()
|
||||
defer h.modelCacheMu.Unlock()
|
||||
|
||||
// double-check
|
||||
ttl = modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
return h.cachedFamilies
|
||||
}
|
||||
|
||||
// 尝试从上游获取
|
||||
families, err := h.fetchUpstreamModels(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
|
||||
families = service.BuildSoraModelFamilies()
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = false
|
||||
return families
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = true
|
||||
return families
|
||||
}
|
||||
|
||||
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
|
||||
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
|
||||
if h.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gatewayService 未初始化")
|
||||
}
|
||||
|
||||
// 设置 ForcePlatform 用于 Sora 账号选择
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
|
||||
// 选择一个 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
|
||||
}
|
||||
|
||||
// 仅支持 API Key 类型账号
|
||||
if account.Type != service.AccountTypeAPIKey {
|
||||
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
|
||||
}
|
||||
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("账号缺少 api_key")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("账号缺少 base_url")
|
||||
}
|
||||
|
||||
// 构建上游模型列表请求
|
||||
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求上游失败: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析 OpenAI 格式的模型列表
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(modelsResp.Data) == 0 {
|
||||
return nil, fmt.Errorf("上游返回空模型列表")
|
||||
}
|
||||
|
||||
// 提取模型 ID
|
||||
modelIDs := make([]string, 0, len(modelsResp.Data))
|
||||
for _, m := range modelsResp.Data {
|
||||
modelIDs = append(modelIDs, m.ID)
|
||||
}
|
||||
|
||||
// 转换为模型家族
|
||||
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
|
||||
if len(families) == 0 {
|
||||
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
|
||||
}
|
||||
|
||||
return families, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,697 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
//
|
||||
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||
func NewSoraGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
streamMode := "force"
|
||||
soraTLSEnabled := true
|
||||
signKey := ""
|
||||
mediaRoot := "/app/data/sora"
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
||||
streamMode = mode
|
||||
}
|
||||
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||
mediaRoot = root
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions handles Sora /v1/chat/completions endpoint
|
||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.sora_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
|
||||
msgsResult := gjson.GetBytes(body, "messages")
|
||||
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
||||
return
|
||||
}
|
||||
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
|
||||
if !clientStream {
|
||||
if h.streamMode == "error" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
||||
return
|
||||
}
|
||||
var err error
|
||||
body, err = sjson.SetBytes(body, "stream", true)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
|
||||
|
||||
platform := ""
|
||||
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
platform = forced
|
||||
} else if apiKey.Group != nil {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
if platform != service.PlatformSora {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
|
||||
return
|
||||
}
|
||||
|
||||
streamStarted := false
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := generateOpenAISessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverBody []byte
|
||||
var lastFailoverHeaders http.Header
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int("last_upstream_status", lastFailoverStatus),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
proxyBound := account.ProxyID != nil
|
||||
proxyID := int64(0)
|
||||
if account.ProxyID != nil {
|
||||
proxyID = *account.ProxyID
|
||||
}
|
||||
tlsFingerprintEnabled := h.soraTLSEnabled
|
||||
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
clientStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_slot_acquire_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
switchCount++
|
||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.String("upstream_error_code", upstreamErrCode),
|
||||
zap.String("upstream_error_message", upstreamErrMsg),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
||||
continue
|
||||
}
|
||||
reqLog.Error("sora.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("sora.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("sora.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
||||
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
|
||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||
switch statusCode {
|
||||
case 401, 403, 404, 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||
}
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 404:
|
||||
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||
}
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHTTPHeaders(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
||||
if headers != nil {
|
||||
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
||||
contentType = strings.TrimSpace(headers.Get("content-type"))
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
||||
return rayID, mitigated, contentType
|
||||
}
|
||||
|
||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||
lower := strings.ToLower(message)
|
||||
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// MediaProxy serves local Sora media files.
|
||||
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
|
||||
h.proxySoraMedia(c, false)
|
||||
}
|
||||
|
||||
// MediaProxySigned serves local Sora media files with signature verification.
|
||||
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
|
||||
h.proxySoraMedia(c, true)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
|
||||
rawPath := c.Param("filepath")
|
||||
if rawPath == "" {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
cleaned := path.Clean(rawPath)
|
||||
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
query := c.Request.URL.Query()
|
||||
if requireSignature {
|
||||
if h.soraMediaSigningKey == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Sora 媒体签名未配置",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
expiresStr := strings.TrimSpace(query.Get("expires"))
|
||||
signature := strings.TrimSpace(query.Get("sig"))
|
||||
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err != nil || expires <= time.Now().Unix() {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "authentication_error",
|
||||
"message": "Sora 媒体签名已过期",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
query.Del("sig")
|
||||
query.Del("expires")
|
||||
signingQuery := query.Encode()
|
||||
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "authentication_error",
|
||||
"message": "Sora 媒体签名无效",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(h.soraMediaRoot) == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Sora 媒体目录未配置",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
relative := strings.TrimPrefix(cleaned, "/")
|
||||
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
c.File(localPath)
|
||||
}
|
||||
@ -1,728 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/testutil"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// 编译期接口断言
|
||||
var _ service.SoraClient = (*stubSoraClient)(nil)
|
||||
var _ service.AccountRepository = (*stubAccountRepo)(nil)
|
||||
var _ service.GroupRepository = (*stubGroupRepo)(nil)
|
||||
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
|
||||
|
||||
type stubSoraClient struct {
|
||||
imageURLs []string
|
||||
}
|
||||
|
||||
func (s *stubSoraClient) Enabled() bool { return true }
|
||||
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
|
||||
return "upload", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
|
||||
return "task-image", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||
return &service.SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
return "enhanced prompt", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
|
||||
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
|
||||
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if acc, ok := r.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, service.ErrAccountNotFound
|
||||
}
|
||||
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||
var result []*service.Account
|
||||
for _, id := range ids {
|
||||
if acc, ok := r.accounts[id]; ok {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
_, ok := r.accounts[id]
|
||||
return ok, nil
|
||||
}
|
||||
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return map[string]int64{}, nil
|
||||
}
|
||||
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
|
||||
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
|
||||
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
|
||||
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||
return r.listSchedulable(), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
return r.listSchedulable(), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
for _, platform := range platforms {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type stubGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
|
||||
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return r.group, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return r.group, nil
|
||||
}
|
||||
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
|
||||
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct{}
|
||||
|
||||
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
RunMode: config.RunModeSimple,
|
||||
Gateway: config.GatewayConfig{
|
||||
SoraStreamMode: "force",
|
||||
MaxAccountSwitches: 1,
|
||||
Scheduling: config.GatewaySchedulingConfig{
|
||||
LoadBatchEnabled: false,
|
||||
},
|
||||
},
|
||||
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.test",
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
|
||||
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
|
||||
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
|
||||
groupRepo := &stubGroupRepo{group: group}
|
||||
|
||||
usageLogRepo := &stubUsageLogRepo{}
|
||||
deferredService := service.NewDeferredService(accountRepo, nil, 0)
|
||||
billingService := service.NewBillingService(cfg, nil)
|
||||
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
|
||||
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
t.Cleanup(func() {
|
||||
billingCacheService.Stop()
|
||||
})
|
||||
|
||||
gatewayService := service.NewGatewayService(
|
||||
accountRepo,
|
||||
groupRepo,
|
||||
usageLogRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
concurrencyService,
|
||||
billingService,
|
||||
nil,
|
||||
billingCacheService,
|
||||
nil,
|
||||
nil,
|
||||
deferredService,
|
||||
nil,
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
nil, // tlsFPProfileService
|
||||
nil, // channelService
|
||||
nil, // resolver
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
||||
|
||||
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
Status: service.StatusActive,
|
||||
GroupID: &group.ID,
|
||||
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
|
||||
Group: group,
|
||||
}
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
|
||||
|
||||
handler.ChatCompletions(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp["media_url"])
|
||||
}
|
||||
|
||||
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
|
||||
func TestSoraHandler_StreamForcing(t *testing.T) {
|
||||
// 测试 1:stream=false 时 sjson 强制修改为 true
|
||||
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
require.False(t, clientStream)
|
||||
newBody, err := sjson.SetBytes(body, "stream", true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
|
||||
|
||||
// 测试 2:stream=true 时不修改
|
||||
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
|
||||
require.True(t, gjson.GetBytes(body2, "stream").Bool())
|
||||
|
||||
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
|
||||
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
|
||||
require.False(t, gjson.GetBytes(body3, "stream").Bool())
|
||||
}
|
||||
|
||||
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
|
||||
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
||||
// model 缺失
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
|
||||
|
||||
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
|
||||
modelResult1b := gjson.GetBytes(body1b, "model")
|
||||
require.True(t, modelResult1b.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult1b.Type)
|
||||
|
||||
// messages 缺失
|
||||
body2 := []byte(`{"model":"sora"}`)
|
||||
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
|
||||
|
||||
// messages 不是 JSON 数组(字符串)
|
||||
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
||||
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
|
||||
|
||||
// messages 是对象而非数组 → IsArray 返回 false
|
||||
body4 := []byte(`{"model":"sora","messages":{}}`)
|
||||
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
|
||||
|
||||
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
|
||||
body5 := []byte(`{"model":"sora","messages":[]}`)
|
||||
msgsResult := gjson.GetBytes(body5, "messages")
|
||||
require.True(t, msgsResult.IsArray())
|
||||
require.Equal(t, 0, len(msgsResult.Array()))
|
||||
|
||||
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
|
||||
}
|
||||
|
||||
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
||||
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 从 body 提取 prompt_cache_key
|
||||
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
hash := generateOpenAISessionHash(c, body)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// 无 prompt_cache_key 且无 header → 空 hash
|
||||
body2 := []byte(`{"model":"sora"}`)
|
||||
hash2 := generateOpenAISessionHash(c, body2)
|
||||
require.Empty(t, hash2)
|
||||
|
||||
// header 优先于 body
|
||||
c.Request.Header.Set("session_id", "from-header")
|
||||
hash3 := generateOpenAISessionHash(c, body)
|
||||
require.NotEmpty(t, hash3)
|
||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||
}
|
||||
|
||||
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号",
|
||||
errType: "upstream_error",
|
||||
message: `upstream returned "invalid" payload`,
|
||||
},
|
||||
{
|
||||
name: "包含换行和制表符",
|
||||
errType: "rate_limit_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠",
|
||||
errType: "upstream_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||
require.Equal(t, tt.errType, errorObj["type"])
|
||||
require.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare challenge")
|
||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare shield")
|
||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||
}
|
||||
|
||||
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
headers.Set("content-type", "text/html")
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
||||
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
||||
require.Equal(t, "challenge", mitigated)
|
||||
require.Equal(t, "text/html", contentType)
|
||||
}
|
||||
@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
@ -86,8 +86,6 @@ func ProvideHandlers(
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
soraClientHandler *SoraClientHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
@ -104,8 +102,6 @@ func ProvideHandlers(
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
SoraClient: soraClientHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewSoraGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
|
||||
@ -17,8 +17,6 @@ import (
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
||||
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
@ -39,8 +37,6 @@ const (
|
||||
const (
|
||||
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
|
||||
OAuthPlatformOpenAI = "openai"
|
||||
// OAuthPlatformSora uses Sora OAuth client.
|
||||
OAuthPlatformSora = "sora"
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
|
||||
}
|
||||
|
||||
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
|
||||
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
|
||||
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
|
||||
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case OAuthPlatformSora:
|
||||
return ClientID, false
|
||||
default:
|
||||
return ClientID, true
|
||||
}
|
||||
return ClientID, true
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
|
||||
@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
|
||||
// 但不启用 codex_cli_simplified_flow。
|
||||
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "" {
|
||||
t.Fatalf("codex flow should be empty for sora, got=%q", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1692,20 +1692,13 @@ func itoa(v int) string {
|
||||
}
|
||||
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
||||
// 该方法限定 platform='sora',避免误查询其他平台的账号。
|
||||
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
||||
//
|
||||
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
|
||||
//
|
||||
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
||||
// Limited to platform='sora' to avoid querying accounts from other platforms.
|
||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||
//
|
||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||
dbaccount.DeletedAtIsNil(),
|
||||
func(s *entsql.Selector) {
|
||||
path := sqljson.Path(key)
|
||||
|
||||
@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldSoraImagePrice360,
|
||||
group.FieldSoraImagePrice540,
|
||||
group.FieldSoraVideoPricePerRequest,
|
||||
group.FieldSoraVideoPricePerRequestHd,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
|
||||
@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
|
||||
@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
|
||||
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
const customClientID = "custom-client-id"
|
||||
var seenClientIDs []string
|
||||
@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
|
||||
wantClientID := openai.SoraClientID
|
||||
wantClientID := "custom-exchange-client-id"
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
|
||||
@ -1,98 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
|
||||
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
|
||||
//
|
||||
// 设计说明:
|
||||
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
|
||||
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
|
||||
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
|
||||
type soraAccountRepository struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
|
||||
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
|
||||
return &soraAccountRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// Upsert 创建或更新 Sora 账号扩展信息
|
||||
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
|
||||
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
||||
accessToken, accessOK := updates["access_token"].(string)
|
||||
refreshToken, refreshOK := updates["refresh_token"].(string)
|
||||
sessionToken, sessionOK := updates["session_token"].(string)
|
||||
|
||||
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
|
||||
if !sessionOK {
|
||||
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
|
||||
}
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_accounts
|
||||
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
|
||||
updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`, accountID, sessionToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, NOW(), NOW())
|
||||
ON CONFLICT (account_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
|
||||
updated_at = NOW()
|
||||
`, accountID, accessToken, refreshToken, sessionToken)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
|
||||
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
|
||||
FROM sora_accounts
|
||||
WHERE account_id = $1
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, nil // 记录不存在
|
||||
}
|
||||
|
||||
var sa service.SoraAccount
|
||||
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sa, nil
|
||||
}
|
||||
|
||||
// Delete 删除 Sora 账号扩展信息
|
||||
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM sora_accounts WHERE account_id = $1
|
||||
`, accountID)
|
||||
return err
|
||||
}
|
||||
@ -1,419 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
|
||||
// 使用原生 SQL 操作 sora_generations 表。
|
||||
type soraGenerationRepository struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
|
||||
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
|
||||
return &soraGenerationRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
|
||||
err := r.sql.QueryRowContext(ctx, `
|
||||
INSERT INTO sora_generations (
|
||||
user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, created_at
|
||||
`,
|
||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
||||
).Scan(&gen.ID, &gen.CreatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
|
||||
func (r *soraGenerationRepository) CreatePendingWithLimit(
|
||||
ctx context.Context,
|
||||
gen *service.SoraGeneration,
|
||||
activeStatuses []string,
|
||||
maxActive int64,
|
||||
) error {
|
||||
if gen == nil {
|
||||
return fmt.Errorf("generation is nil")
|
||||
}
|
||||
if maxActive <= 0 {
|
||||
return r.Create(ctx, gen)
|
||||
}
|
||||
if len(activeStatuses) == 0 {
|
||||
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
|
||||
}
|
||||
|
||||
tx, err := r.sql.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
|
||||
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
placeholders := make([]string, len(activeStatuses))
|
||||
args := make([]any, 0, 1+len(activeStatuses))
|
||||
args = append(args, gen.UserID)
|
||||
for i, s := range activeStatuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
args = append(args, s)
|
||||
}
|
||||
countQuery := fmt.Sprintf(
|
||||
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
|
||||
strings.Join(placeholders, ","),
|
||||
)
|
||||
var activeCount int64
|
||||
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if activeCount >= maxActive {
|
||||
return service.ErrSoraGenerationConcurrencyLimit
|
||||
}
|
||||
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO sora_generations (
|
||||
user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, created_at
|
||||
`,
|
||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
||||
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
|
||||
gen := &service.SoraGeneration{}
|
||||
var mediaURLsJSON, s3KeysJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
var apiKeyID sql.NullInt64
|
||||
|
||||
err := r.sql.QueryRowContext(ctx, `
|
||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
||||
created_at, completed_at
|
||||
FROM sora_generations WHERE id = $1
|
||||
`, id).Scan(
|
||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
||||
&gen.CreatedAt, &completedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("生成记录不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if apiKeyID.Valid {
|
||||
gen.APIKeyID = &apiKeyID.Int64
|
||||
}
|
||||
if completedAt.Valid {
|
||||
gen.CompletedAt = &completedAt.Time
|
||||
}
|
||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
|
||||
var completedAt *time.Time
|
||||
if gen.CompletedAt != nil {
|
||||
completedAt = gen.CompletedAt
|
||||
}
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations SET
|
||||
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
|
||||
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
|
||||
error_message = $9, completed_at = $10
|
||||
WHERE id = $1
|
||||
`,
|
||||
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
|
||||
gen.ErrorMessage, completedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
|
||||
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2, upstream_task_id = $3
|
||||
WHERE id = $1 AND status = $4
|
||||
`,
|
||||
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
|
||||
func (r *soraGenerationRepository) UpdateCompletedIfActive(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
completedAt time.Time,
|
||||
) (bool, error) {
|
||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2,
|
||||
media_url = $3,
|
||||
media_urls = $4,
|
||||
file_size_bytes = $5,
|
||||
storage_type = $6,
|
||||
s3_object_keys = $7,
|
||||
error_message = '',
|
||||
completed_at = $8
|
||||
WHERE id = $1 AND status IN ($9, $10)
|
||||
`,
|
||||
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
|
||||
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
|
||||
func (r *soraGenerationRepository) UpdateFailedIfActive(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
errMsg string,
|
||||
completedAt time.Time,
|
||||
) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2,
|
||||
error_message = $3,
|
||||
completed_at = $4
|
||||
WHERE id = $1 AND status IN ($5, $6)
|
||||
`,
|
||||
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
|
||||
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2, completed_at = $3
|
||||
WHERE id = $1 AND status IN ($4, $5)
|
||||
`,
|
||||
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
|
||||
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
) (bool, error) {
|
||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET media_url = $2,
|
||||
media_urls = $3,
|
||||
file_size_bytes = $4,
|
||||
storage_type = $5,
|
||||
s3_object_keys = $6
|
||||
WHERE id = $1 AND status = $7
|
||||
`,
|
||||
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
|
||||
// 构建 WHERE 条件
|
||||
conditions := []string{"user_id = $1"}
|
||||
args := []any{params.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if params.Status != "" {
|
||||
// 支持逗号分隔的多状态
|
||||
statuses := strings.Split(params.Status, ",")
|
||||
placeholders := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, strings.TrimSpace(s))
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
if params.StorageType != "" {
|
||||
storageTypes := strings.Split(params.StorageType, ",")
|
||||
placeholders := make([]string, len(storageTypes))
|
||||
for i, s := range storageTypes {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, strings.TrimSpace(s))
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
if params.MediaType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
|
||||
args = append(args, params.MediaType)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := "WHERE " + strings.Join(conditions, " AND ")
|
||||
|
||||
// 计数
|
||||
var total int64
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
|
||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
listQuery := fmt.Sprintf(`
|
||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
||||
created_at, completed_at
|
||||
FROM sora_generations %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIdx, argIdx+1)
|
||||
args = append(args, params.PageSize, offset)
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
var results []*service.SoraGeneration
|
||||
for rows.Next() {
|
||||
gen := &service.SoraGeneration{}
|
||||
var mediaURLsJSON, s3KeysJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
var apiKeyID sql.NullInt64
|
||||
|
||||
if err := rows.Scan(
|
||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
||||
&gen.CreatedAt, &completedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if apiKeyID.Valid {
|
||||
gen.APIKeyID = &apiKeyID.Int64
|
||||
}
|
||||
if completedAt.Valid {
|
||||
gen.CompletedAt = &completedAt.Time
|
||||
}
|
||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
||||
results = append(results, gen)
|
||||
}
|
||||
|
||||
return results, total, rows.Err()
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
|
||||
if len(statuses) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
placeholders := make([]string, len(statuses))
|
||||
args := []any{userID}
|
||||
for i, s := range statuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
args = append(args, s)
|
||||
}
|
||||
|
||||
var count int64
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
|
||||
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
||||
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
|
||||
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
|
||||
if deltaBytes <= 0 {
|
||||
user, err := r.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.SoraStorageUsedBytes, nil
|
||||
}
|
||||
var newUsed int64
|
||||
err := scanSingleRow(ctx, r.sql, `
|
||||
UPDATE users
|
||||
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
|
||||
WHERE id = $1
|
||||
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
|
||||
RETURNING sora_storage_used_bytes
|
||||
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
|
||||
if err == nil {
|
||||
return newUsed, nil
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// 区分用户不存在和配额冲突
|
||||
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
|
||||
if existsErr != nil {
|
||||
return 0, existsErr
|
||||
}
|
||||
if !exists {
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
return 0, service.ErrSoraStorageQuotaExceeded
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
|
||||
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
|
||||
if deltaBytes <= 0 {
|
||||
user, err := r.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.SoraStorageUsedBytes, nil
|
||||
}
|
||||
var newUsed int64
|
||||
err := scanSingleRow(ctx, r.sql, `
|
||||
UPDATE users
|
||||
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
|
||||
WHERE id = $1
|
||||
RETURNING sora_storage_used_bytes
|
||||
`, []any{userID, deltaBytes}, &newUsed)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return newUsed, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||
NewProxyRepository,
|
||||
|
||||
@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
|
||||
return strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/sora/") ||
|
||||
strings.HasPrefix(path, "/responses")
|
||||
}
|
||||
|
||||
|
||||
@ -109,7 +109,6 @@ func registerRoutes(
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@ -34,8 +34,6 @@ func RegisterAdminRoutes(
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
|
||||
registerSoraOAuthRoutes(admin, h)
|
||||
|
||||
// Gemini OAuth
|
||||
registerGeminiOAuthRoutes(admin, h)
|
||||
@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
sora := admin.Group("/sora")
|
||||
{
|
||||
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
|
||||
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
gemini := admin.Group("/gemini")
|
||||
{
|
||||
|
||||
@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
|
||||
cfg *config.Config,
|
||||
) {
|
||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
|
||||
if soraMaxBodySize <= 0 {
|
||||
soraMaxBodySize = cfg.Gateway.MaxBodySize
|
||||
}
|
||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
endpointNorm := handler.InboundEndpointMiddleware()
|
||||
@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
|
||||
}
|
||||
|
||||
// Sora 专用路由(强制使用 sora 平台)
|
||||
soraV1 := r.Group("/sora/v1")
|
||||
soraV1.Use(soraBodyLimit)
|
||||
soraV1.Use(clientRequestID)
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(endpointNorm)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
soraV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
soraV1.GET("/models", h.Gateway.Models)
|
||||
}
|
||||
|
||||
// Sora 媒体代理(可选 API Key 验证)
|
||||
if cfg.Gateway.SoraMediaRequireAPIKey {
|
||||
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
|
||||
} else {
|
||||
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
|
||||
}
|
||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
||||
}
|
||||
|
||||
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
||||
|
||||
@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine {
|
||||
&handler.Handlers{
|
||||
Gateway: &handler.GatewayHandler{},
|
||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||
SoraGateway: &handler.SoraGatewayHandler{},
|
||||
},
|
||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
@ -1,36 +0,0 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
|
||||
func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
|
||||
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
|
||||
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
|
||||
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
|
||||
authenticated.GET("/quota", h.SoraClient.GetQuota)
|
||||
authenticated.GET("/models", h.SoraClient.GetModels)
|
||||
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
|
||||
}
|
||||
}
|
||||
@ -28,8 +28,7 @@ type AccountRepository interface {
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
||||
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号
|
||||
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||
// for all accounts that have been synced from CRS.
|
||||
|
||||
@ -13,18 +13,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
|
||||
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
|
||||
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
@ -71,13 +62,8 @@ type AccountTestService struct {
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
soraTestGuardMu sync.Mutex
|
||||
soraTestLastRun map[int64]time.Time
|
||||
soraTestCooldown time.Duration
|
||||
}
|
||||
|
||||
const defaultSoraTestCooldown = 10 * time.Second
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
@ -94,8 +80,6 @@ func NewAccountTestService(
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
soraTestLastRun: make(map[int64]time.Time),
|
||||
soraTestCooldown: defaultSoraTestCooldown,
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformSora {
|
||||
return s.testSoraAccountConnection(c, account)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
@ -634,697 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
return s.processGeminiStream(c, resp.Body)
|
||||
}
|
||||
|
||||
type soraProbeStep struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
HTTPStatus int `json:"http_status,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type soraProbeSummary struct {
|
||||
Status string `json:"status"`
|
||||
Steps []soraProbeStep `json:"steps"`
|
||||
}
|
||||
|
||||
type soraProbeRecorder struct {
|
||||
steps []soraProbeStep
|
||||
}
|
||||
|
||||
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
|
||||
r.steps = append(r.steps, soraProbeStep{
|
||||
Name: name,
|
||||
Status: status,
|
||||
HTTPStatus: httpStatus,
|
||||
ErrorCode: strings.TrimSpace(errorCode),
|
||||
Message: strings.TrimSpace(message),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *soraProbeRecorder) finalize() soraProbeSummary {
|
||||
meSuccess := false
|
||||
partial := false
|
||||
for _, step := range r.steps {
|
||||
if step.Name == "me" {
|
||||
meSuccess = strings.EqualFold(step.Status, "success")
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(step.Status, "failed") {
|
||||
partial = true
|
||||
}
|
||||
}
|
||||
|
||||
status := "success"
|
||||
if !meSuccess {
|
||||
status = "failed"
|
||||
} else if partial {
|
||||
status = "partial_success"
|
||||
}
|
||||
|
||||
return soraProbeSummary{
|
||||
Status: status,
|
||||
Steps: append([]soraProbeStep(nil), r.steps...),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
|
||||
if rec == nil {
|
||||
return
|
||||
}
|
||||
summary := rec.finalize()
|
||||
code := ""
|
||||
for _, step := range summary.Steps {
|
||||
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
|
||||
code = step.ErrorCode
|
||||
break
|
||||
}
|
||||
}
|
||||
s.sendEvent(c, TestEvent{
|
||||
Type: "sora_test_result",
|
||||
Status: summary.Status,
|
||||
Code: code,
|
||||
Data: summary,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
|
||||
if accountID <= 0 {
|
||||
return 0, true
|
||||
}
|
||||
s.soraTestGuardMu.Lock()
|
||||
defer s.soraTestGuardMu.Unlock()
|
||||
|
||||
if s.soraTestLastRun == nil {
|
||||
s.soraTestLastRun = make(map[int64]time.Time)
|
||||
}
|
||||
cooldown := s.soraTestCooldown
|
||||
if cooldown <= 0 {
|
||||
cooldown = defaultSoraTestCooldown
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
|
||||
elapsed := now.Sub(lastRun)
|
||||
if elapsed < cooldown {
|
||||
return cooldown - elapsed, false
|
||||
}
|
||||
}
|
||||
s.soraTestLastRun[accountID] = now
|
||||
return 0, true
|
||||
}
|
||||
|
||||
func ceilSeconds(d time.Duration) int {
|
||||
if d <= 0 {
|
||||
return 1
|
||||
}
|
||||
sec := int(d / time.Second)
|
||||
if d%time.Second != 0 {
|
||||
sec++
|
||||
}
|
||||
if sec < 1 {
|
||||
sec = 1
|
||||
}
|
||||
return sec
|
||||
}
|
||||
|
||||
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
|
||||
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
|
||||
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
|
||||
}
|
||||
|
||||
// 验证 base_url 格式
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
|
||||
}
|
||||
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
|
||||
|
||||
// 设置 SSE 头
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
||||
return s.sendErrorAndEnd(c, msg)
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
|
||||
|
||||
// 构建轻量级 prompt-enhance 请求作为连通性测试
|
||||
testPayload := map[string]any{
|
||||
"model": "prompt-enhance-short-10s",
|
||||
"messages": []map[string]string{{"role": "user", "content": "test"}},
|
||||
"stream": false,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(testPayload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "构建测试请求失败")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// 获取代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
|
||||
}
|
||||
|
||||
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
|
||||
}
|
||||
|
||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
||||
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
|
||||
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
|
||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
||||
// apikey 类型走独立测试流程
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
return s.testSoraAPIKeyAccountConnection(c, account)
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
recorder := &soraProbeRecorder{}
|
||||
|
||||
authToken := account.GetCredential("access_token")
|
||||
if authToken == "" {
|
||||
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
||||
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, msg)
|
||||
}
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
|
||||
if err != nil {
|
||||
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// 使用 Sora 客户端标准请求头
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
soraTLSProfile := s.resolveSoraTLSProfile()
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
|
||||
if err != nil {
|
||||
recorder.addStep("me", "failed", 0, "network_error", err.Error())
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
|
||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
|
||||
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
|
||||
case strings.TrimSpace(upstreamMessage) != "":
|
||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
|
||||
default:
|
||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||
}
|
||||
}
|
||||
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
|
||||
|
||||
// 解析 /me 响应,提取用户信息
|
||||
var meResp map[string]any
|
||||
if err := json.Unmarshal(body, &meResp); err != nil {
|
||||
// 能收到 200 就说明 token 有效
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
|
||||
} else {
|
||||
// 尝试提取用户名或邮箱信息
|
||||
info := "Sora connection OK"
|
||||
if name, ok := meResp["name"].(string); ok && name != "" {
|
||||
info = fmt.Sprintf("Sora connection OK - User: %s", name)
|
||||
} else if email, ok := meResp["email"].(string); ok && email != "" {
|
||||
info = fmt.Sprintf("Sora connection OK - Email: %s", email)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
||||
}
|
||||
|
||||
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
|
||||
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
|
||||
if err == nil {
|
||||
subReq.Header.Set("Authorization", "Bearer "+authToken)
|
||||
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
subReq.Header.Set("Accept", "application/json")
|
||||
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
|
||||
if subErr != nil {
|
||||
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
||||
} else {
|
||||
subBody, _ := io.ReadAll(subResp.Body)
|
||||
_ = subResp.Body.Close()
|
||||
if subResp.StatusCode == http.StatusOK {
|
||||
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
|
||||
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
||||
}
|
||||
} else {
|
||||
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
|
||||
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
|
||||
} else {
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
|
||||
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
|
||||
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
|
||||
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) testSora2Capabilities(
|
||||
c *gin.Context,
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
authToken string,
|
||||
proxyURL string,
|
||||
tlsProfile *tlsfingerprint.Profile,
|
||||
recorder *soraProbeRecorder,
|
||||
) {
|
||||
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraInviteMineURL,
|
||||
proxyURL,
|
||||
tlsProfile,
|
||||
)
|
||||
if err != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
|
||||
return
|
||||
}
|
||||
|
||||
if inviteStatus == http.StatusUnauthorized {
|
||||
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraBootstrapURL,
|
||||
proxyURL,
|
||||
tlsProfile,
|
||||
)
|
||||
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
|
||||
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraInviteMineURL,
|
||||
proxyURL,
|
||||
tlsProfile,
|
||||
)
|
||||
if err != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
|
||||
return
|
||||
}
|
||||
} else if recorder != nil {
|
||||
code := ""
|
||||
msg := ""
|
||||
if bootstrapErr != nil {
|
||||
code = "network_error"
|
||||
msg = bootstrapErr.Error()
|
||||
}
|
||||
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if inviteStatus != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||
}
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
|
||||
return
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
|
||||
return
|
||||
}
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
|
||||
}
|
||||
|
||||
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
|
||||
}
|
||||
|
||||
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
account,
|
||||
authToken,
|
||||
soraRemainingURL,
|
||||
proxyURL,
|
||||
tlsProfile,
|
||||
)
|
||||
if remainingErr != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
|
||||
return
|
||||
}
|
||||
if remainingStatus != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||
}
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
|
||||
return
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
|
||||
return
|
||||
}
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
|
||||
}
|
||||
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) fetchSoraTestEndpoint(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
authToken string,
|
||||
url string,
|
||||
proxyURL string,
|
||||
tlsProfile *tlsfingerprint.Profile,
|
||||
) (int, http.Header, []byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return resp.StatusCode, resp.Header, nil, readErr
|
||||
}
|
||||
return resp.StatusCode, resp.Header, body, nil
|
||||
}
|
||||
|
||||
func parseSoraSubscriptionSummary(body []byte) string {
|
||||
var subResp struct {
|
||||
Data []struct {
|
||||
Plan struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
} `json:"plan"`
|
||||
EndTS string `json:"end_ts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &subResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
if len(subResp.Data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
first := subResp.Data[0]
|
||||
parts := make([]string, 0, 3)
|
||||
if first.Plan.Title != "" {
|
||||
parts = append(parts, first.Plan.Title)
|
||||
}
|
||||
if first.Plan.ID != "" {
|
||||
parts = append(parts, first.Plan.ID)
|
||||
}
|
||||
if first.EndTS != "" {
|
||||
parts = append(parts, "end="+first.EndTS)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "Subscription: " + strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func parseSoraInviteSummary(body []byte) string {
|
||||
var inviteResp struct {
|
||||
InviteCode string `json:"invite_code"`
|
||||
RedeemedCount int64 `json:"redeemed_count"`
|
||||
TotalCount int64 `json:"total_count"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &inviteResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := []string{"Sora2: supported"}
|
||||
if inviteResp.InviteCode != "" {
|
||||
parts = append(parts, "invite="+inviteResp.InviteCode)
|
||||
}
|
||||
if inviteResp.TotalCount > 0 {
|
||||
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
|
||||
}
|
||||
return strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func parseSoraRemainingSummary(body []byte) string {
|
||||
var remainingResp struct {
|
||||
RateLimitAndCreditBalance struct {
|
||||
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
|
||||
RateLimitReached bool `json:"rate_limit_reached"`
|
||||
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
|
||||
} `json:"rate_limit_and_credit_balance"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &remainingResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
info := remainingResp.RateLimitAndCreditBalance
|
||||
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
|
||||
if info.RateLimitReached {
|
||||
parts = append(parts, "rate_limited=true")
|
||||
}
|
||||
if info.AccessResetsInSeconds > 0 {
|
||||
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
|
||||
}
|
||||
return strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
|
||||
if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
|
||||
// Sora TLS fingerprint enabled — use built-in default profile
|
||||
return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
|
||||
}
|
||||
return nil // disabled
|
||||
}
|
||||
|
||||
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
return soraerror.ExtractCloudflareRayID(headers, body)
|
||||
}
|
||||
|
||||
func extractSoraEgressIPHint(headers http.Header) string {
|
||||
if headers == nil {
|
||||
return "unknown"
|
||||
}
|
||||
candidates := []string{
|
||||
"x-openai-public-ip",
|
||||
"x-envoy-external-address",
|
||||
"cf-connecting-ip",
|
||||
"x-forwarded-for",
|
||||
}
|
||||
for _, key := range candidates {
|
||||
if value := strings.TrimSpace(headers.Get(key)); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func sanitizeProxyURLForLog(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "<invalid_proxy_url>"
|
||||
}
|
||||
if u.User != nil {
|
||||
u.User = nil
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func endpointPathForLog(endpoint string) string {
|
||||
parsed, err := url.Parse(strings.TrimSpace(endpoint))
|
||||
if err != nil || parsed.Path == "" {
|
||||
return endpoint
|
||||
}
|
||||
return parsed.Path
|
||||
}
|
||||
|
||||
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
|
||||
accountID := int64(0)
|
||||
platform := ""
|
||||
proxyID := "none"
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
platform = account.Platform
|
||||
if account.ProxyID != nil {
|
||||
proxyID = fmt.Sprintf("%d", *account.ProxyID)
|
||||
}
|
||||
}
|
||||
cfRay := extractCloudflareRayID(headers, body)
|
||||
if cfRay == "" {
|
||||
cfRay = "unknown"
|
||||
}
|
||||
log.Printf(
|
||||
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
|
||||
accountID,
|
||||
platform,
|
||||
endpoint,
|
||||
endpointPathForLog(endpoint),
|
||||
proxyID,
|
||||
sanitizeProxyURLForLog(proxyURL),
|
||||
cfRay,
|
||||
extractSoraEgressIPHint(headers),
|
||||
)
|
||||
}
|
||||
|
||||
func truncateSoraErrorBody(body []byte, max int) string {
|
||||
return soraerror.TruncateBody(body, max)
|
||||
}
|
||||
|
||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||
|
||||
@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ctx, recorder := newSoraTestContext()
|
||||
ctx, recorder := newTestContext()
|
||||
svc := &AccountTestService{}
|
||||
|
||||
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -14,6 +15,14 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
@ -34,7 +43,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
@ -68,7 +77,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
ctx, _ := newTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
|
||||
@ -1,320 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type queuedHTTPUpstream struct {
|
||||
responses []*http.Response
|
||||
requests []*http.Request
|
||||
tlsFlags []bool
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected Do call")
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
u.requests = append(u.requests, req)
|
||||
u.tlsFlags = append(u.tlsFlags, profile != nil)
|
||||
if len(u.responses) == 0 {
|
||||
return nil, fmt.Errorf("no mocked response")
|
||||
}
|
||||
resp := u.responses[0]
|
||||
u.responses = u.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func newJSONResponse(status int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
||||
resp := newJSONResponse(status, body)
|
||||
resp.Header.Set(key, value)
|
||||
return resp
|
||||
}
|
||||
|
||||
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
|
||||
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
|
||||
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
TLSFingerprint: config.TLSFingerprintConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
DisableTLSFingerprint: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
|
||||
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
|
||||
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
|
||||
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
|
||||
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"test_start"`)
|
||||
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
|
||||
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
||||
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
|
||||
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
||||
require.Contains(t, body, "Subscription check returned 403")
|
||||
require.Contains(t, body, "Sora2 invite check returned 401")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"partial_success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "HTTP 429")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "token_invalidated")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.Contains(t, body, "token_invalidated")
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
soraTestCooldown: time.Hour,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c1, _ := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c1, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
c2, rec2 := newSoraTestContext()
|
||||
err = svc.testSoraAccountConnection(c2, account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "测试过于频繁")
|
||||
body := rec2.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"code":"test_rate_limited"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestSanitizeProxyURLForLog(t *testing.T) {
|
||||
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
|
||||
require.Equal(t, "", sanitizeProxyURLForLog(""))
|
||||
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
|
||||
}
|
||||
|
||||
func TestExtractSoraEgressIPHint(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("x-openai-public-ip", "203.0.113.10")
|
||||
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
|
||||
|
||||
h2 := make(http.Header)
|
||||
h2.Set("x-envoy-external-address", "198.51.100.9")
|
||||
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
|
||||
|
||||
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
|
||||
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
|
||||
}
|
||||
@ -15,7 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/httputil"
|
||||
)
|
||||
|
||||
// AdminService interface defines admin management operations
|
||||
@ -111,7 +111,6 @@ type CreateUserInput struct {
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
SoraStorageQuotaBytes int64
|
||||
}
|
||||
|
||||
type UpdateUserInput struct {
|
||||
@ -126,7 +125,6 @@ type UpdateUserInput struct {
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
SoraStorageQuotaBytes *int64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@ -143,11 +141,6 @@ type CreateGroupInput struct {
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64
|
||||
SoraImagePrice540 *float64
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
@ -158,8 +151,6 @@ type CreateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
DefaultMappedModel string
|
||||
@ -184,11 +175,6 @@ type UpdateGroupInput struct {
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64
|
||||
SoraImagePrice540 *float64
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
@ -199,8 +185,6 @@ type UpdateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool
|
||||
DefaultMappedModel *string
|
||||
@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{
|
||||
http.StatusOK: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
Target: "sora",
|
||||
URL: "https://sora.chatgpt.com/backend/me",
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
@ -448,7 +424,6 @@ type adminServiceImpl struct {
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
@ -473,7 +448,6 @@ func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
@ -492,7 +466,6 @@ func NewAdminService(
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
soraAccountRepo: soraAccountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
@ -582,7 +555,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
|
||||
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
|
||||
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
|
||||
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
|
||||
// 校验降级分组
|
||||
if input.FallbackGroupID != nil {
|
||||
@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
SoraImagePrice360: soraImagePrice360,
|
||||
SoraImagePrice540: soraImagePrice540,
|
||||
SoraVideoPricePerRequest: soraVideoPrice,
|
||||
SoraVideoPricePerRequestHD: soraVideoPriceHD,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||
ModelRouting: input.ModelRouting,
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: input.RequireOAuthOnly,
|
||||
RequirePrivacySet: input.RequirePrivacySet,
|
||||
@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.ImagePrice4K != nil {
|
||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||
}
|
||||
if input.SoraImagePrice360 != nil {
|
||||
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
|
||||
}
|
||||
if input.SoraImagePrice540 != nil {
|
||||
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
|
||||
}
|
||||
if input.SoraVideoPricePerRequest != nil {
|
||||
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
|
||||
}
|
||||
if input.SoraVideoPricePerRequestHD != nil {
|
||||
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
}
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := input.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Notes: normalizeAccountNotes(input.Notes),
|
||||
@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
|
||||
if account.Platform == PlatformSora && s.soraAccountRepo != nil {
|
||||
soraUpdates := map[string]any{
|
||||
"access_token": account.GetCredential("access_token"),
|
||||
"refresh_token": account.GetCredential("refresh_token"),
|
||||
}
|
||||
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
|
||||
// 只记录警告日志,不阻塞账号创建
|
||||
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(groupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
||||
@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := account.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
|
||||
body = body[:proxyQualityMaxBodyBytes]
|
||||
}
|
||||
|
||||
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||
// Cloudflare challenge 检测
|
||||
if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||
item.Status = "challenge"
|
||||
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
|
||||
item.Message = "Sora 命中 Cloudflare challenge"
|
||||
item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
|
||||
item.Message = "命中 Cloudflare challenge"
|
||||
return item
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
|
||||
require.Contains(t, result.Summary, "挑战 1 项")
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
||||
func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("cf-ray", "test-ray-123")
|
||||
@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "sora",
|
||||
Target: "openai",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
|
||||
@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
|
||||
@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||
@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||
|
||||
@ -808,14 +808,6 @@ type ImagePriceConfig struct {
|
||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||
}
|
||||
|
||||
// SoraPriceConfig Sora 按次计费配置
|
||||
type SoraPriceConfig struct {
|
||||
ImagePrice360 *float64
|
||||
ImagePrice540 *float64
|
||||
VideoPricePerRequest *float64
|
||||
VideoPricePerRequestHD *float64
|
||||
}
|
||||
|
||||
// CalculateImageCost 计算图片生成费用
|
||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||
@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateSoraImageCost 计算 Sora 图片按次费用
|
||||
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
unitPrice := 0.0
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "540":
|
||||
if groupConfig.ImagePrice540 != nil {
|
||||
unitPrice = *groupConfig.ImagePrice540
|
||||
}
|
||||
default:
|
||||
if groupConfig.ImagePrice360 != nil {
|
||||
unitPrice = *groupConfig.ImagePrice360
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateSoraVideoCost 计算 Sora 视频按次费用
|
||||
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
unitPrice := 0.0
|
||||
if groupConfig != nil {
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "sora2pro-hd") {
|
||||
if groupConfig.VideoPricePerRequestHD != nil {
|
||||
unitPrice = *groupConfig.VideoPricePerRequestHD
|
||||
}
|
||||
}
|
||||
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
|
||||
unitPrice = *groupConfig.VideoPricePerRequest
|
||||
}
|
||||
}
|
||||
|
||||
totalCost := unitPrice
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// getImageUnitPrice 获取图片单价
|
||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||
// 优先使用分组配置的价格
|
||||
|
||||
@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
|
||||
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price := 0.5
|
||||
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
|
||||
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
|
||||
|
||||
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
hdPrice := 1.0
|
||||
normalPrice := 0.5
|
||||
cfg := &SoraPriceConfig{
|
||||
VideoPricePerRequest: &normalPrice,
|
||||
VideoPricePerRequestHD: &hdPrice,
|
||||
}
|
||||
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
|
||||
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestIsModelSupported(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "not initialized")
|
||||
}
|
||||
|
||||
func TestCalculateSoraImageCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price360 := 0.05
|
||||
price540 := 0.08
|
||||
cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
|
||||
|
||||
cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
|
||||
|
||||
cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
|
||||
require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
|
||||
require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
|
||||
// 使用空的 fallback prices 让 GetModelPricing 失败
|
||||
svc := &BillingService{
|
||||
|
||||
@ -24,7 +24,6 @@ const (
|
||||
PlatformOpenAI = domain.PlatformOpenAI
|
||||
PlatformGemini = domain.PlatformGemini
|
||||
PlatformAntigravity = domain.PlatformAntigravity
|
||||
PlatformSora = domain.PlatformSora
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
@ -107,7 +106,6 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
@ -199,27 +197,6 @@ const (
|
||||
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||
|
||||
// =========================
|
||||
// Sora S3 存储配置
|
||||
// =========================
|
||||
|
||||
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
|
||||
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
|
||||
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
|
||||
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
|
||||
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
|
||||
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
|
||||
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
|
||||
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
|
||||
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
|
||||
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
|
||||
|
||||
// =========================
|
||||
// Sora 用户存储配额
|
||||
// =========================
|
||||
|
||||
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
|
||||
|
||||
// =========================
|
||||
// Claude Code Version Check
|
||||
// =========================
|
||||
|
||||
@ -60,13 +60,6 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
// MediaType 媒体类型常量
|
||||
const (
|
||||
MediaTypeImage = "image"
|
||||
MediaTypeVideo = "video"
|
||||
MediaTypePrompt = "prompt"
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@ -510,10 +503,6 @@ type ForwardResult struct {
|
||||
// 图片生成计费字段(图片生成模型使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||
|
||||
// Sora 媒体字段
|
||||
MediaType string // image / video / prompt
|
||||
MediaURL string // 生成后的媒体地址(可选)
|
||||
}
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
@ -1341,6 +1330,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
return false
|
||||
@ -1349,12 +1343,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return excluded
|
||||
}
|
||||
|
||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
|
||||
// 获取模型路由配置(仅 anthropic 平台)
|
||||
var routingAccountIDs []int64
|
||||
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
||||
@ -1442,24 +1430,19 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
var stickyCacheMissReason string
|
||||
|
||||
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||
if s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
||||
|
||||
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
|
||||
|
||||
if rpmPass { // 粘性会话窗口费用+RPM 检查
|
||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
@ -1473,49 +1456,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
if stickyCacheMissReason == "" {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
stickyCacheMissReason = "wait_queue_full"
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
} else if !gatePass {
|
||||
stickyCacheMissReason = "gate_check"
|
||||
} else {
|
||||
stickyCacheMissReason = "rpm_red"
|
||||
}
|
||||
|
||||
// 记录粘性缓存未命中的结构化日志
|
||||
if stickyCacheMissReason != "" {
|
||||
baseRPM := stickyAccount.GetBaseRPM()
|
||||
var currentRPM int
|
||||
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
|
||||
currentRPM = count
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
|
||||
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
|
||||
}
|
||||
} else {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
|
||||
stickyAccountID, shortSessionHash(sessionHash))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1621,7 +1582,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
account, ok := accountByID[accountID]
|
||||
if ok {
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
// Check if the account needs sticky session cleanup
|
||||
clearSticky := shouldClearStickySession(account, requestedModel)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
@ -1637,7 +1597,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
// Session count limit check
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
@ -1652,10 +1611,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
// Session count limit check (wait plan also requires session quota)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
// Session limit full, continue to Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
@ -1971,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
||||
}
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
if platform == PlatformSora {
|
||||
return s.listSoraSchedulableAccounts(ctx, groupID)
|
||||
}
|
||||
if s.schedulerSnapshot != nil {
|
||||
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err == nil {
|
||||
@ -2070,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
|
||||
const useMixed = false
|
||||
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", PlatformSora,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Platform != PlatformSora {
|
||||
continue
|
||||
}
|
||||
if !s.isSoraAccountSchedulable(&acc) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
slog.Debug("account_scheduling_list_sora",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", PlatformSora,
|
||||
"raw_count", len(accounts),
|
||||
"filtered_count", len(filtered))
|
||||
for _, acc := range filtered {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
|
||||
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
|
||||
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
|
||||
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
|
||||
@ -2141,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
|
||||
return account.Platform == platform
|
||||
}
|
||||
|
||||
func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
|
||||
return s.soraUnschedulableReason(account) == ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) soraUnschedulableReason(account *Account) string {
|
||||
if account == nil {
|
||||
return "account_nil"
|
||||
}
|
||||
if account.Status != StatusActive {
|
||||
return fmt.Sprintf("status=%s", account.Status)
|
||||
}
|
||||
if !account.Schedulable {
|
||||
return "schedulable=false"
|
||||
}
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Platform == PlatformSora {
|
||||
return s.isSoraAccountSchedulable(account)
|
||||
}
|
||||
return account.IsSchedulable()
|
||||
}
|
||||
|
||||
@ -2175,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Platform == PlatformSora {
|
||||
if !s.isSoraAccountSchedulable(account) {
|
||||
return false
|
||||
}
|
||||
return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
|
||||
}
|
||||
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
|
||||
}
|
||||
|
||||
@ -2795,12 +2673,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@ -2824,7 +2696,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@ -2872,12 +2744,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@ -2983,12 +2849,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@ -3055,12 +2915,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@ -3128,12 +2982,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@ -3203,7 +3051,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@ -3227,7 +3075,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
@ -3240,12 +3087,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@ -3357,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
|
||||
stats.SampleMappingIDs,
|
||||
stats.SampleRateLimitIDs,
|
||||
)
|
||||
if platform == PlatformSora {
|
||||
s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
@ -3417,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
}
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
detail := "generic_unschedulable"
|
||||
if acc.Platform == PlatformSora {
|
||||
detail = s.soraUnschedulableReason(acc)
|
||||
}
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
|
||||
}
|
||||
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
||||
@ -3444,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
return selectionFailureDiagnosis{Category: "eligible"}
|
||||
}
|
||||
|
||||
func (s *GatewayService) logSoraSelectionFailureDetails(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
sessionHash string,
|
||||
requestedModel string,
|
||||
accounts []Account,
|
||||
excludedIDs map[int64]struct{},
|
||||
allowMixedScheduling bool,
|
||||
) {
|
||||
const maxLines = 30
|
||||
logged := 0
|
||||
|
||||
for i := range accounts {
|
||||
if logged >= maxLines {
|
||||
break
|
||||
}
|
||||
acc := &accounts[i]
|
||||
diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
|
||||
if diagnosis.Category == "eligible" {
|
||||
continue
|
||||
}
|
||||
detail := diagnosis.Detail
|
||||
if detail == "" {
|
||||
detail = "-"
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.gateway",
|
||||
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
|
||||
derefGroupID(groupID),
|
||||
requestedModel,
|
||||
shortSessionHash(sessionHash),
|
||||
acc.ID,
|
||||
acc.Platform,
|
||||
diagnosis.Category,
|
||||
detail,
|
||||
)
|
||||
logged++
|
||||
}
|
||||
if len(accounts) > maxLines {
|
||||
logger.LegacyPrintf(
|
||||
"service.gateway",
|
||||
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
|
||||
derefGroupID(groupID),
|
||||
requestedModel,
|
||||
shortSessionHash(sessionHash),
|
||||
len(accounts),
|
||||
logged,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
||||
if acc == nil {
|
||||
return true
|
||||
@ -3573,13 +3358,14 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
}
|
||||
return mapAntigravityModel(account, requestedModel) != ""
|
||||
}
|
||||
if account.Platform == PlatformSora {
|
||||
return s.isSoraModelSupportedByAccount(account, requestedModel)
|
||||
}
|
||||
if account.IsBedrock() {
|
||||
_, ok := ResolveBedrockModelID(account, requestedModel)
|
||||
return ok
|
||||
}
|
||||
// OpenAI 透传模式:仅替换认证,允许所有模型
|
||||
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
|
||||
return true
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
@ -3588,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(requestedModel) == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 先走原始精确/通配符匹配。
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
|
||||
return true
|
||||
}
|
||||
|
||||
aliases := buildSoraModelAliases(requestedModel)
|
||||
if len(aliases) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasSoraSelector := false
|
||||
for pattern := range mapping {
|
||||
if !isSoraModelSelector(pattern) {
|
||||
continue
|
||||
}
|
||||
hasSoraSelector = true
|
||||
if matchPatternAnyAlias(pattern, aliases) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
|
||||
// 此时不应误拦截 Sora 模型请求。
|
||||
if !hasSoraSelector {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func matchPatternAnyAlias(pattern string, aliases []string) bool {
|
||||
normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
|
||||
if normalizedPattern == "" {
|
||||
return false
|
||||
}
|
||||
for _, alias := range aliases {
|
||||
if matchWildcard(normalizedPattern, alias) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isSoraModelSelector(pattern string) bool {
|
||||
p := strings.ToLower(strings.TrimSpace(pattern))
|
||||
if p == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(p, "sora"),
|
||||
strings.HasPrefix(p, "gpt-image"),
|
||||
strings.HasPrefix(p, "prompt-enhance"),
|
||||
strings.HasPrefix(p, "sy_"):
|
||||
return true
|
||||
}
|
||||
|
||||
return p == "video" || p == "image"
|
||||
}
|
||||
|
||||
func buildSoraModelAliases(requestedModel string) []string {
|
||||
modelID := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
aliases := make([]string, 0, 8)
|
||||
addAlias := func(value string) {
|
||||
v := strings.ToLower(strings.TrimSpace(value))
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
for _, existing := range aliases {
|
||||
if existing == v {
|
||||
return
|
||||
}
|
||||
}
|
||||
aliases = append(aliases, v)
|
||||
}
|
||||
|
||||
addAlias(modelID)
|
||||
cfg, ok := GetSoraModelConfig(modelID)
|
||||
if ok {
|
||||
addAlias(cfg.Model)
|
||||
switch cfg.Type {
|
||||
case "video":
|
||||
addAlias("video")
|
||||
addAlias("sora")
|
||||
addAlias(soraVideoFamilyAlias(modelID))
|
||||
case "image":
|
||||
addAlias("image")
|
||||
addAlias("gpt-image")
|
||||
case "prompt_enhance":
|
||||
addAlias("prompt-enhance")
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(modelID, "sora"):
|
||||
addAlias("video")
|
||||
addAlias("sora")
|
||||
addAlias(soraVideoFamilyAlias(modelID))
|
||||
case strings.HasPrefix(modelID, "gpt-image"):
|
||||
addAlias("image")
|
||||
addAlias("gpt-image")
|
||||
case strings.HasPrefix(modelID, "prompt-enhance"):
|
||||
addAlias("prompt-enhance")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return aliases
|
||||
}
|
||||
|
||||
func soraVideoFamilyAlias(modelID string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(modelID, "sora2pro-hd"):
|
||||
return "sora2pro-hd"
|
||||
case strings.HasPrefix(modelID, "sora2pro"):
|
||||
return "sora2pro"
|
||||
case strings.HasPrefix(modelID, "sora2"):
|
||||
return "sora2"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
@ -7434,6 +7083,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ParsedRequest *ParsedRequest
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
@ -7745,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
||||
|
||||
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||
type recordUsageOpts struct {
|
||||
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||
// ParsedRequest(可选,仅 Claude 路径传入)
|
||||
ParsedRequest *ParsedRequest
|
||||
|
||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||
// - Claude Max 缓存计费策略
|
||||
// - Sora 媒体类型分支(image/video/prompt)
|
||||
// - MediaType 字段写入使用日志
|
||||
EnableClaudePath bool
|
||||
|
||||
@ -7776,6 +7424,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
APIKeyService: input.APIKeyService,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
ParsedRequest: input.ParsedRequest,
|
||||
EnableClaudePath: true,
|
||||
})
|
||||
}
|
||||
@ -7841,8 +7490,6 @@ type recordUsageCoreInput struct {
|
||||
|
||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||
// opts 中的字段控制两者之间的差异行为:
|
||||
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
|
||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||
result := input.Result
|
||||
@ -7944,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost(
|
||||
multiplier float64,
|
||||
opts *recordUsageOpts,
|
||||
) *CostBreakdown {
|
||||
// Sora 媒体类型分支(仅 Claude 路径启用)
|
||||
if opts.EnableClaudePath {
|
||||
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
|
||||
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
|
||||
}
|
||||
if result.MediaType == MediaTypePrompt {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
}
|
||||
|
||||
// 图片生成计费
|
||||
if result.ImageCount > 0 {
|
||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
||||
@ -7963,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost(
|
||||
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||
}
|
||||
|
||||
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
|
||||
func (s *GatewayService) calculateSoraMediaCost(
|
||||
result *ForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
) *CostBreakdown {
|
||||
var soraConfig *SoraPriceConfig
|
||||
if apiKey.Group != nil {
|
||||
soraConfig = &SoraPriceConfig{
|
||||
ImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
ImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
}
|
||||
}
|
||||
if result.MediaType == MediaTypeImage {
|
||||
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
}
|
||||
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
|
||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||
@ -8163,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
}
|
||||
|
||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
|
||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
||||
isSoraMedia := opts.EnableClaudePath &&
|
||||
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
||||
if isSoraMedia {
|
||||
return nil
|
||||
}
|
||||
var mode string
|
||||
switch {
|
||||
case cost != nil && cost.BillingMode != "":
|
||||
@ -8183,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
|
||||
}
|
||||
|
||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
||||
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
||||
return &result.MediaType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -8293,6 +7899,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
|
||||
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
|
||||
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
|
||||
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
|
||||
if groupID == nil {
|
||||
return false
|
||||
}
|
||||
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
|
||||
return false
|
||||
}
|
||||
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
@ -9,35 +9,35 @@ import (
|
||||
|
||||
func TestCollectSelectionFailureStats(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
model := "gpt-5.4"
|
||||
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
|
||||
|
||||
accounts := []Account{
|
||||
// excluded
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
// unschedulable
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
},
|
||||
// platform filtered
|
||||
{
|
||||
ID: 3,
|
||||
Platform: PlatformOpenAI,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
// model unsupported
|
||||
{
|
||||
ID: 4,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
|
||||
// model rate limited
|
||||
{
|
||||
ID: 5,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
|
||||
// eligible
|
||||
{
|
||||
ID: 6,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
|
||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false)
|
||||
|
||||
if stats.Total != 6 {
|
||||
t.Fatalf("total=%d want=6", stats.Total)
|
||||
@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
|
||||
func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
acc := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
}
|
||||
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false)
|
||||
if diagnosis.Category != "unschedulable" {
|
||||
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
|
||||
}
|
||||
if diagnosis.Detail != "schedulable=false" {
|
||||
t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
|
||||
if diagnosis.Detail != "generic_unschedulable" {
|
||||
t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
|
||||
func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
model := "gpt-5.4"
|
||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||
acc := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformSora,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false)
|
||||
if diagnosis.Category != "model_rate_limited" {
|
||||
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
|
||||
}
|
||||
|
||||
@ -1,79 +0,0 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected sora model to be supported when model_mapping is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-4o": "gpt-4o",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"sora2": "sora2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
|
||||
t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"sy_8": "sy_8",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-image": "gpt-image",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
|
||||
}
|
||||
}
|
||||
@ -1,89 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
now := time.Now()
|
||||
past := now.Add(-1 * time.Minute)
|
||||
future := now.Add(5 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
AutoPauseOnExpired: true,
|
||||
ExpiresAt: &past,
|
||||
OverloadUntil: &future,
|
||||
RateLimitResetAt: &future,
|
||||
}
|
||||
|
||||
if !svc.isAccountSchedulableForSelection(acc) {
|
||||
t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
future := time.Now().Add(5 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
}
|
||||
|
||||
if svc.isAccountSchedulableForSelection(acc) {
|
||||
t.Fatalf("expected non-sora account to keep generic schedulable checks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||
globalResetAt := time.Now().Add(2 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &globalResetAt,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
model: map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
|
||||
t.Fatalf("expected sora account to be blocked by model scope rate limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
future := time.Now().Add(3 * time.Minute)
|
||||
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
},
|
||||
}
|
||||
|
||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
||||
if stats.Unschedulable != 0 || stats.Eligible != 1 {
|
||||
t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
|
||||
}
|
||||
}
|
||||
@ -26,15 +26,6 @@ type Group struct {
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Sora 按次计费配置(阶段 1)
|
||||
SoraImagePrice360 *float64
|
||||
SoraImagePrice540 *float64
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
}
|
||||
}
|
||||
|
||||
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
|
||||
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
|
||||
switch imageSize {
|
||||
case "360":
|
||||
return g.SoraImagePrice360
|
||||
case "540":
|
||||
return g.SoraImagePrice540
|
||||
default:
|
||||
return g.SoraImagePrice360
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
|
||||
@ -3,30 +3,15 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
|
||||
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
|
||||
|
||||
type soraSessionChunk struct {
|
||||
index int
|
||||
value string
|
||||
}
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
|
||||
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
|
||||
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err != nil {
|
||||
@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
|
||||
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
|
||||
}
|
||||
|
||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
|
||||
if strings.TrimSpace(sessionToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||
}
|
||||
|
||||
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
|
||||
}
|
||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var sessionResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
Expires string `json:"expires"`
|
||||
User struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
} `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &sessionResp); err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(sessionResp.AccessToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||
if strings.TrimSpace(sessionResp.Expires) != "" {
|
||||
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
|
||||
expiresAt = parsed.Unix()
|
||||
}
|
||||
}
|
||||
expiresIn := expiresAt - time.Now().Unix()
|
||||
if expiresIn < 0 {
|
||||
expiresIn = 0
|
||||
}
|
||||
|
||||
return &OpenAITokenInfo{
|
||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||
ExpiresIn: expiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
ClientID: openai.SoraClientID,
|
||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeSoraSessionTokenInput(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
|
||||
if len(matches) == 0 {
|
||||
return sanitizeSessionToken(trimmed)
|
||||
}
|
||||
|
||||
chunkMatches := make([]soraSessionChunk, 0, len(matches))
|
||||
singleValues := make([]string, 0, len(matches))
|
||||
|
||||
for _, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
value := sanitizeSessionToken(match[2])
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(match[1]) == "" {
|
||||
singleValues = append(singleValues, value)
|
||||
continue
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
|
||||
if err != nil || idx < 0 {
|
||||
continue
|
||||
}
|
||||
chunkMatches = append(chunkMatches, soraSessionChunk{
|
||||
index: idx,
|
||||
value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
|
||||
return merged
|
||||
}
|
||||
|
||||
if len(singleValues) > 0 {
|
||||
return singleValues[len(singleValues)-1]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
byIndex := make(map[int]string, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
byIndex[chunk.index] = chunk.value
|
||||
}
|
||||
|
||||
if _, ok := byIndex[0]; !ok {
|
||||
return ""
|
||||
}
|
||||
if requireComplete {
|
||||
for idx := 0; idx <= requiredMaxIndex; idx++ {
|
||||
if _, ok := byIndex[idx]; !ok {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orderedIndexes := make([]int, 0, len(byIndex))
|
||||
for idx := range byIndex {
|
||||
orderedIndexes = append(orderedIndexes, idx)
|
||||
}
|
||||
sort.Ints(orderedIndexes)
|
||||
|
||||
var builder strings.Builder
|
||||
for _, idx := range orderedIndexes {
|
||||
if _, err := builder.WriteString(byIndex[idx]); err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return sanitizeSessionToken(builder.String())
|
||||
}
|
||||
|
||||
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requiredMaxIndex := 0
|
||||
for _, chunk := range chunks {
|
||||
if chunk.index > requiredMaxIndex {
|
||||
requiredMaxIndex = chunk.index
|
||||
}
|
||||
}
|
||||
|
||||
groupStarts := make([]int, 0, len(chunks))
|
||||
for idx, chunk := range chunks {
|
||||
if chunk.index == 0 {
|
||||
groupStarts = append(groupStarts, idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupStarts) == 0 {
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
for i := len(groupStarts) - 1; i >= 0; i-- {
|
||||
start := groupStarts[i]
|
||||
end := len(chunks)
|
||||
if i+1 < len(groupStarts) {
|
||||
end = groupStarts[i+1]
|
||||
}
|
||||
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
|
||||
return merged
|
||||
}
|
||||
}
|
||||
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
func sanitizeSessionToken(raw string) string {
|
||||
token := strings.TrimSpace(raw)
|
||||
token = strings.Trim(token, "\"'`")
|
||||
token = strings.TrimSuffix(token, ";")
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||
// RefreshAccountToken refreshes token for an OpenAI OAuth account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
|
||||
if account.Platform != PlatformOpenAI {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||||
@ -609,10 +389,5 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
func normalizeOpenAIOAuthPlatform(platform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case PlatformSora:
|
||||
return openai.OAuthPlatformSora
|
||||
default:
|
||||
return openai.OAuthPlatformOpenAI
|
||||
}
|
||||
return openai.OAuthPlatformOpenAI
|
||||
}
|
||||
|
||||
@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.ClientID, session.ClientID)
|
||||
}
|
||||
|
||||
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
|
||||
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
|
||||
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.AuthURL)
|
||||
require.NotEmpty(t, result.SessionID)
|
||||
|
||||
parsed, err := url.Parse(result.AuthURL)
|
||||
require.NoError(t, err)
|
||||
q := parsed.Query()
|
||||
require.Equal(t, openai.ClientID, q.Get("client_id"))
|
||||
require.Empty(t, q.Get("codex_cli_simplified_flow"))
|
||||
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.ClientID, session.ClientID)
|
||||
}
|
||||
|
||||
@ -1,173 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientNoopStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
require.Equal(t, "demo@example.com", info.Email)
|
||||
require.Greater(t, info.ExpiresAt, int64(0))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
|
||||
// OpenAITokenCache token cache interface.
|
||||
type OpenAITokenCache = GeminiTokenCache
|
||||
|
||||
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
|
||||
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
|
||||
type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai/sora oauth account")
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
p.metrics.refreshRequests.Add(1)
|
||||
p.metrics.touchNow()
|
||||
|
||||
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
||||
p.metrics.lockContention.Add(1)
|
||||
p.metrics.touchNow()
|
||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||
if waitErr != nil {
|
||||
return "", waitErr
|
||||
}
|
||||
if strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else if result.Refreshed {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
||||
p.metrics.lockContention.Add(1)
|
||||
p.metrics.touchNow()
|
||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||
if waitErr != nil {
|
||||
return "", waitErr
|
||||
}
|
||||
if strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else if result.Refreshed {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
|
||||
@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai/sora oauth account")
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai/sora oauth account")
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
|
||||
@ -22,8 +22,6 @@ import (
|
||||
var (
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
||||
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
|
||||
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
|
||||
"default subscription group must exist and be subscription type",
|
||||
@ -104,7 +102,6 @@ type SettingService struct {
|
||||
defaultSubGroupReader DefaultSubscriptionGroupReader
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
onS3Update func() // Callback when Sora S3 settings are updated
|
||||
version string // Application version
|
||||
}
|
||||
|
||||
@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyHideCcsImportButton,
|
||||
SettingKeyPurchaseSubscriptionEnabled,
|
||||
SettingKeyPurchaseSubscriptionURL,
|
||||
SettingKeySoraClientEnabled,
|
||||
SettingKeyCustomMenuItems,
|
||||
SettingKeyCustomEndpoints,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||
s.onUpdate = callback
|
||||
}
|
||||
|
||||
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
|
||||
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
|
||||
s.onS3Update = callback
|
||||
}
|
||||
|
||||
// SetVersion sets the application version for injection into public settings
|
||||
func (s *SettingService) SetVersion(version string) {
|
||||
s.version = version
|
||||
@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
||||
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
|
||||
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
|
||||
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
|
||||
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
|
||||
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
|
||||
|
||||
@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
SettingKeyPurchaseSubscriptionURL: "",
|
||||
SettingKeySoraClientEnabled: "false",
|
||||
SettingKeyCustomMenuItems: "[]",
|
||||
SettingKeyCustomEndpoints: "[]",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
@ -1584,606 +1569,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
|
||||
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
|
||||
}
|
||||
|
||||
type soraS3ProfilesStore struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []soraS3ProfileStoreItem `json:"items"`
|
||||
}
|
||||
|
||||
type soraS3ProfileStoreItem struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
|
||||
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
|
||||
profiles, err := s.ListSoraS3Profiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
|
||||
if activeProfile == nil {
|
||||
return &SoraS3Settings{}, nil
|
||||
}
|
||||
|
||||
return &SoraS3Settings{
|
||||
Enabled: activeProfile.Enabled,
|
||||
Endpoint: activeProfile.Endpoint,
|
||||
Region: activeProfile.Region,
|
||||
Bucket: activeProfile.Bucket,
|
||||
AccessKeyID: activeProfile.AccessKeyID,
|
||||
SecretAccessKey: activeProfile.SecretAccessKey,
|
||||
SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
|
||||
Prefix: activeProfile.Prefix,
|
||||
ForcePathStyle: activeProfile.ForcePathStyle,
|
||||
CDNURL: activeProfile.CDNURL,
|
||||
DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
|
||||
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
|
||||
if settings == nil {
|
||||
return fmt.Errorf("settings cannot be nil")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
|
||||
if activeIndex < 0 {
|
||||
activeID := "default"
|
||||
if hasSoraS3ProfileID(store.Items, activeID) {
|
||||
activeID = fmt.Sprintf("default-%d", time.Now().Unix())
|
||||
}
|
||||
store.Items = append(store.Items, soraS3ProfileStoreItem{
|
||||
ProfileID: activeID,
|
||||
Name: "Default",
|
||||
UpdatedAt: now,
|
||||
})
|
||||
store.ActiveProfileID = activeID
|
||||
activeIndex = len(store.Items) - 1
|
||||
}
|
||||
|
||||
active := store.Items[activeIndex]
|
||||
active.Enabled = settings.Enabled
|
||||
active.Endpoint = strings.TrimSpace(settings.Endpoint)
|
||||
active.Region = strings.TrimSpace(settings.Region)
|
||||
active.Bucket = strings.TrimSpace(settings.Bucket)
|
||||
active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
|
||||
active.Prefix = strings.TrimSpace(settings.Prefix)
|
||||
active.ForcePathStyle = settings.ForcePathStyle
|
||||
active.CDNURL = strings.TrimSpace(settings.CDNURL)
|
||||
active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
|
||||
if settings.SecretAccessKey != "" {
|
||||
active.SecretAccessKey = settings.SecretAccessKey
|
||||
}
|
||||
active.UpdatedAt = now
|
||||
store.Items[activeIndex] = active
|
||||
|
||||
return s.persistSoraS3ProfilesStore(ctx, store)
|
||||
}
|
||||
|
||||
// ListSoraS3Profiles 获取 Sora S3 多配置列表
|
||||
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertSoraS3ProfilesStore(store), nil
|
||||
}
|
||||
|
||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
|
||||
if profile == nil {
|
||||
return nil, fmt.Errorf("profile cannot be nil")
|
||||
}
|
||||
|
||||
profileID := strings.TrimSpace(profile.ProfileID)
|
||||
if profileID == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasSoraS3ProfileID(store.Items, profileID) {
|
||||
return nil, ErrSoraS3ProfileExists
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
store.Items = append(store.Items, soraS3ProfileStoreItem{
|
||||
ProfileID: profileID,
|
||||
Name: name,
|
||||
Enabled: profile.Enabled,
|
||||
Endpoint: strings.TrimSpace(profile.Endpoint),
|
||||
Region: strings.TrimSpace(profile.Region),
|
||||
Bucket: strings.TrimSpace(profile.Bucket),
|
||||
AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
|
||||
SecretAccessKey: profile.SecretAccessKey,
|
||||
Prefix: strings.TrimSpace(profile.Prefix),
|
||||
ForcePathStyle: profile.ForcePathStyle,
|
||||
CDNURL: strings.TrimSpace(profile.CDNURL),
|
||||
DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
|
||||
UpdatedAt: now,
|
||||
})
|
||||
|
||||
if setActive || store.ActiveProfileID == "" {
|
||||
store.ActiveProfileID = profileID
|
||||
}
|
||||
|
||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := convertSoraS3ProfilesStore(store)
|
||||
created := findSoraS3ProfileByID(profiles.Items, profileID)
|
||||
if created == nil {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
|
||||
if profile == nil {
|
||||
return nil, fmt.Errorf("profile cannot be nil")
|
||||
}
|
||||
|
||||
targetID := strings.TrimSpace(profileID)
|
||||
if targetID == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
||||
if targetIndex < 0 {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
|
||||
target := store.Items[targetIndex]
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
|
||||
}
|
||||
target.Name = name
|
||||
target.Enabled = profile.Enabled
|
||||
target.Endpoint = strings.TrimSpace(profile.Endpoint)
|
||||
target.Region = strings.TrimSpace(profile.Region)
|
||||
target.Bucket = strings.TrimSpace(profile.Bucket)
|
||||
target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
|
||||
target.Prefix = strings.TrimSpace(profile.Prefix)
|
||||
target.ForcePathStyle = profile.ForcePathStyle
|
||||
target.CDNURL = strings.TrimSpace(profile.CDNURL)
|
||||
target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
|
||||
if profile.SecretAccessKey != "" {
|
||||
target.SecretAccessKey = profile.SecretAccessKey
|
||||
}
|
||||
target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
store.Items[targetIndex] = target
|
||||
|
||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := convertSoraS3ProfilesStore(store)
|
||||
updated := findSoraS3ProfileByID(profiles.Items, targetID)
|
||||
if updated == nil {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
|
||||
targetID := strings.TrimSpace(profileID)
|
||||
if targetID == "" {
|
||||
return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
||||
if targetIndex < 0 {
|
||||
return ErrSoraS3ProfileNotFound
|
||||
}
|
||||
|
||||
store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
|
||||
if store.ActiveProfileID == targetID {
|
||||
store.ActiveProfileID = ""
|
||||
if len(store.Items) > 0 {
|
||||
store.ActiveProfileID = store.Items[0].ProfileID
|
||||
}
|
||||
}
|
||||
|
||||
return s.persistSoraS3ProfilesStore(ctx, store)
|
||||
}
|
||||
|
||||
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
|
||||
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
|
||||
targetID := strings.TrimSpace(profileID)
|
||||
if targetID == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
||||
if targetIndex < 0 {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
|
||||
store.ActiveProfileID = targetID
|
||||
store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := convertSoraS3ProfilesStore(store)
|
||||
active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
|
||||
if active == nil {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
return active, nil
|
||||
}
|
||||
|
||||
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
|
||||
if err == nil {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return &soraS3ProfilesStore{}, nil
|
||||
}
|
||||
var store soraS3ProfilesStore
|
||||
if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
|
||||
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
|
||||
if legacyErr != nil {
|
||||
return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
|
||||
}
|
||||
if isEmptyLegacySoraS3Settings(legacy) {
|
||||
return &soraS3ProfilesStore{}, nil
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
return &soraS3ProfilesStore{
|
||||
ActiveProfileID: "default",
|
||||
Items: []soraS3ProfileStoreItem{
|
||||
{
|
||||
ProfileID: "default",
|
||||
Name: "Default",
|
||||
Enabled: legacy.Enabled,
|
||||
Endpoint: strings.TrimSpace(legacy.Endpoint),
|
||||
Region: strings.TrimSpace(legacy.Region),
|
||||
Bucket: strings.TrimSpace(legacy.Bucket),
|
||||
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
|
||||
SecretAccessKey: legacy.SecretAccessKey,
|
||||
Prefix: strings.TrimSpace(legacy.Prefix),
|
||||
ForcePathStyle: legacy.ForcePathStyle,
|
||||
CDNURL: strings.TrimSpace(legacy.CDNURL),
|
||||
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
normalized := normalizeSoraS3ProfilesStore(store)
|
||||
return &normalized, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, ErrSettingNotFound) {
|
||||
return nil, fmt.Errorf("get sora s3 profiles: %w", err)
|
||||
}
|
||||
|
||||
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
|
||||
if legacyErr != nil {
|
||||
return nil, legacyErr
|
||||
}
|
||||
if isEmptyLegacySoraS3Settings(legacy) {
|
||||
return &soraS3ProfilesStore{}, nil
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
return &soraS3ProfilesStore{
|
||||
ActiveProfileID: "default",
|
||||
Items: []soraS3ProfileStoreItem{
|
||||
{
|
||||
ProfileID: "default",
|
||||
Name: "Default",
|
||||
Enabled: legacy.Enabled,
|
||||
Endpoint: strings.TrimSpace(legacy.Endpoint),
|
||||
Region: strings.TrimSpace(legacy.Region),
|
||||
Bucket: strings.TrimSpace(legacy.Bucket),
|
||||
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
|
||||
SecretAccessKey: legacy.SecretAccessKey,
|
||||
Prefix: strings.TrimSpace(legacy.Prefix),
|
||||
ForcePathStyle: legacy.ForcePathStyle,
|
||||
CDNURL: strings.TrimSpace(legacy.CDNURL),
|
||||
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
|
||||
if store == nil {
|
||||
return fmt.Errorf("sora s3 profiles store cannot be nil")
|
||||
}
|
||||
|
||||
normalized := normalizeSoraS3ProfilesStore(*store)
|
||||
data, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sora s3 profiles: %w", err)
|
||||
}
|
||||
|
||||
updates := map[string]string{
|
||||
SettingKeySoraS3Profiles: string(data),
|
||||
}
|
||||
|
||||
active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
|
||||
if active == nil {
|
||||
updates[SettingKeySoraS3Enabled] = "false"
|
||||
updates[SettingKeySoraS3Endpoint] = ""
|
||||
updates[SettingKeySoraS3Region] = ""
|
||||
updates[SettingKeySoraS3Bucket] = ""
|
||||
updates[SettingKeySoraS3AccessKeyID] = ""
|
||||
updates[SettingKeySoraS3Prefix] = ""
|
||||
updates[SettingKeySoraS3ForcePathStyle] = "false"
|
||||
updates[SettingKeySoraS3CDNURL] = ""
|
||||
updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
|
||||
updates[SettingKeySoraS3SecretAccessKey] = ""
|
||||
} else {
|
||||
updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
|
||||
updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
|
||||
updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
|
||||
updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
|
||||
updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
|
||||
updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
|
||||
updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
|
||||
updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
|
||||
updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
|
||||
updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
|
||||
}
|
||||
|
||||
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.onUpdate != nil {
|
||||
s.onUpdate()
|
||||
}
|
||||
if s.onS3Update != nil {
|
||||
s.onS3Update()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
|
||||
keys := []string{
|
||||
SettingKeySoraS3Enabled,
|
||||
SettingKeySoraS3Endpoint,
|
||||
SettingKeySoraS3Region,
|
||||
SettingKeySoraS3Bucket,
|
||||
SettingKeySoraS3AccessKeyID,
|
||||
SettingKeySoraS3SecretAccessKey,
|
||||
SettingKeySoraS3Prefix,
|
||||
SettingKeySoraS3ForcePathStyle,
|
||||
SettingKeySoraS3CDNURL,
|
||||
SettingKeySoraDefaultStorageQuotaBytes,
|
||||
}
|
||||
|
||||
values, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
|
||||
}
|
||||
|
||||
result := &SoraS3Settings{
|
||||
Enabled: values[SettingKeySoraS3Enabled] == "true",
|
||||
Endpoint: values[SettingKeySoraS3Endpoint],
|
||||
Region: values[SettingKeySoraS3Region],
|
||||
Bucket: values[SettingKeySoraS3Bucket],
|
||||
AccessKeyID: values[SettingKeySoraS3AccessKeyID],
|
||||
SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
|
||||
SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
|
||||
Prefix: values[SettingKeySoraS3Prefix],
|
||||
ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
|
||||
CDNURL: values[SettingKeySoraS3CDNURL],
|
||||
}
|
||||
if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
|
||||
result.DefaultStorageQuotaBytes = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
|
||||
seen := make(map[string]struct{}, len(store.Items))
|
||||
normalized := soraS3ProfilesStore{
|
||||
ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
|
||||
Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
for idx := range store.Items {
|
||||
item := store.Items[idx]
|
||||
item.ProfileID = strings.TrimSpace(item.ProfileID)
|
||||
if item.ProfileID == "" {
|
||||
item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
|
||||
}
|
||||
if _, exists := seen[item.ProfileID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[item.ProfileID] = struct{}{}
|
||||
|
||||
item.Name = strings.TrimSpace(item.Name)
|
||||
if item.Name == "" {
|
||||
item.Name = item.ProfileID
|
||||
}
|
||||
item.Endpoint = strings.TrimSpace(item.Endpoint)
|
||||
item.Region = strings.TrimSpace(item.Region)
|
||||
item.Bucket = strings.TrimSpace(item.Bucket)
|
||||
item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
|
||||
item.Prefix = strings.TrimSpace(item.Prefix)
|
||||
item.CDNURL = strings.TrimSpace(item.CDNURL)
|
||||
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
|
||||
item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
|
||||
if item.UpdatedAt == "" {
|
||||
item.UpdatedAt = now
|
||||
}
|
||||
normalized.Items = append(normalized.Items, item)
|
||||
}
|
||||
|
||||
if len(normalized.Items) == 0 {
|
||||
normalized.ActiveProfileID = ""
|
||||
return normalized
|
||||
}
|
||||
|
||||
if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
|
||||
return normalized
|
||||
}
|
||||
|
||||
normalized.ActiveProfileID = normalized.Items[0].ProfileID
|
||||
return normalized
|
||||
}
|
||||
|
||||
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
|
||||
if store == nil {
|
||||
return &SoraS3ProfileList{}
|
||||
}
|
||||
items := make([]SoraS3Profile, 0, len(store.Items))
|
||||
for idx := range store.Items {
|
||||
item := store.Items[idx]
|
||||
items = append(items, SoraS3Profile{
|
||||
ProfileID: item.ProfileID,
|
||||
Name: item.Name,
|
||||
IsActive: item.ProfileID == store.ActiveProfileID,
|
||||
Enabled: item.Enabled,
|
||||
Endpoint: item.Endpoint,
|
||||
Region: item.Region,
|
||||
Bucket: item.Bucket,
|
||||
AccessKeyID: item.AccessKeyID,
|
||||
SecretAccessKey: item.SecretAccessKey,
|
||||
SecretAccessKeyConfigured: item.SecretAccessKey != "",
|
||||
Prefix: item.Prefix,
|
||||
ForcePathStyle: item.ForcePathStyle,
|
||||
CDNURL: item.CDNURL,
|
||||
DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
})
|
||||
}
|
||||
return &SoraS3ProfileList{
|
||||
ActiveProfileID: store.ActiveProfileID,
|
||||
Items: items,
|
||||
}
|
||||
}
|
||||
|
||||
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == activeProfileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &items[0]
|
||||
}
|
||||
|
||||
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == activeProfileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &items[0]
|
||||
}
|
||||
|
||||
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
|
||||
return findSoraS3ProfileIndex(items, profileID) >= 0
|
||||
}
|
||||
|
||||
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
|
||||
if settings == nil {
|
||||
return true
|
||||
}
|
||||
if settings.Enabled {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Endpoint) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Region) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Bucket) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.AccessKeyID) != "" {
|
||||
return false
|
||||
}
|
||||
if settings.SecretAccessKey != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Prefix) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.CDNURL) != "" {
|
||||
return false
|
||||
}
|
||||
return settings.DefaultStorageQuotaBytes == 0
|
||||
}
|
||||
|
||||
func maxInt64(value int64, min int64) int64 {
|
||||
if value < min {
|
||||
return min
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@ -41,7 +41,6 @@ type SystemSettings struct {
|
||||
HideCcsImportButton bool
|
||||
PurchaseSubscriptionEnabled bool
|
||||
PurchaseSubscriptionURL string
|
||||
SoraClientEnabled bool
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
@ -107,7 +106,6 @@ type PublicSettings struct {
|
||||
|
||||
PurchaseSubscriptionEnabled bool
|
||||
PurchaseSubscriptionURL string
|
||||
SoraClientEnabled bool
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
@ -116,46 +114,6 @@ type PublicSettings struct {
|
||||
Version string
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 多配置项(服务内部模型)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// SoraS3ProfileList Sora S3 多配置列表
|
||||
type SoraS3ProfileList struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
|
||||
type StreamTimeoutSettings struct {
|
||||
// Enabled 是否启用流超时处理
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// SoraAccountRepository Sora 账号扩展表仓储接口
|
||||
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
|
||||
//
|
||||
// 设计说明:
|
||||
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
|
||||
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
|
||||
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
|
||||
// - Token 刷新时需要同时更新两个表以保持数据一致性
|
||||
type SoraAccountRepository interface {
|
||||
// Upsert 创建或更新 Sora 账号扩展信息
|
||||
// accountID: 关联的 accounts.id
|
||||
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
|
||||
//
|
||||
// 如果记录不存在则创建,存在则更新。
|
||||
// 用于:
|
||||
// 1. 创建 Sora 账号时初始化扩展表
|
||||
// 2. Token 刷新时同步更新扩展表
|
||||
Upsert(ctx context.Context, accountID int64, updates map[string]any) error
|
||||
|
||||
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
|
||||
// 返回 nil, nil 表示记录不存在(非错误)
|
||||
GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
|
||||
|
||||
// Delete 删除 Sora 账号扩展信息
|
||||
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
|
||||
Delete(ctx context.Context, accountID int64) error
|
||||
}
|
||||
|
||||
// SoraAccount Sora 账号扩展信息
|
||||
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
|
||||
type SoraAccount struct {
|
||||
AccountID int64 // 关联的 accounts.id
|
||||
AccessToken string // OAuth access_token
|
||||
RefreshToken string // OAuth refresh_token
|
||||
SessionToken string // Session token(可选,用于 ST→AT 兜底)
|
||||
}
|
||||
@ -1,117 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SoraClient 定义直连 Sora 的任务操作接口。
|
||||
type SoraClient interface {
|
||||
Enabled() bool
|
||||
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
||||
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
||||
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
||||
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
|
||||
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
|
||||
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
|
||||
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
|
||||
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
|
||||
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
|
||||
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
|
||||
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
|
||||
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
|
||||
DeletePost(ctx context.Context, account *Account, postID string) error
|
||||
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
|
||||
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
|
||||
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
||||
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
||||
}
|
||||
|
||||
// SoraImageRequest 图片生成请求参数
|
||||
type SoraImageRequest struct {
|
||||
Prompt string
|
||||
Width int
|
||||
Height int
|
||||
MediaID string
|
||||
}
|
||||
|
||||
// SoraVideoRequest 视频生成请求参数
|
||||
type SoraVideoRequest struct {
|
||||
Prompt string
|
||||
Orientation string
|
||||
Frames int
|
||||
Model string
|
||||
Size string
|
||||
VideoCount int
|
||||
MediaID string
|
||||
RemixTargetID string
|
||||
CameoIDs []string
|
||||
}
|
||||
|
||||
// SoraStoryboardRequest 分镜视频生成请求参数
|
||||
type SoraStoryboardRequest struct {
|
||||
Prompt string
|
||||
Orientation string
|
||||
Frames int
|
||||
Model string
|
||||
Size string
|
||||
MediaID string
|
||||
}
|
||||
|
||||
// SoraImageTaskStatus 图片任务状态
|
||||
type SoraImageTaskStatus struct {
|
||||
ID string
|
||||
Status string
|
||||
ProgressPct float64
|
||||
URLs []string
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
// SoraVideoTaskStatus 视频任务状态
|
||||
type SoraVideoTaskStatus struct {
|
||||
ID string
|
||||
Status string
|
||||
ProgressPct int
|
||||
URLs []string
|
||||
GenerationID string
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
// SoraCameoStatus 角色处理中间态
|
||||
type SoraCameoStatus struct {
|
||||
Status string
|
||||
StatusMessage string
|
||||
DisplayNameHint string
|
||||
UsernameHint string
|
||||
ProfileAssetURL string
|
||||
InstructionSetHint any
|
||||
InstructionSet any
|
||||
}
|
||||
|
||||
// SoraCharacterFinalizeRequest 角色定稿请求参数
|
||||
type SoraCharacterFinalizeRequest struct {
|
||||
CameoID string
|
||||
Username string
|
||||
DisplayName string
|
||||
ProfileAssetPointer string
|
||||
InstructionSet any
|
||||
}
|
||||
|
||||
// SoraUpstreamError 上游错误
|
||||
type SoraUpstreamError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func (e *SoraUpstreamError) Error() string {
|
||||
if e == nil {
|
||||
return "sora upstream error"
|
||||
}
|
||||
if e.Message != "" {
|
||||
return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
|
||||
}
|
||||
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,564 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var _ SoraClient = (*stubSoraClientForPoll)(nil)
|
||||
|
||||
type stubSoraClientForPoll struct {
|
||||
imageStatus *SoraImageTaskStatus
|
||||
videoStatus *SoraVideoTaskStatus
|
||||
imageCalls int
|
||||
videoCalls int
|
||||
enhanced string
|
||||
enhanceErr error
|
||||
storyboard bool
|
||||
videoReq SoraVideoRequest
|
||||
parseErr error
|
||||
postCalls int
|
||||
deleteCalls int
|
||||
}
|
||||
|
||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||
func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
|
||||
return "task-image", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||
s.videoReq = req
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
s.storyboard = true
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
return &SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||
s.postCalls++
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||
s.deleteCalls++
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||
if s.parseErr != nil {
|
||||
return "", s.parseErr
|
||||
}
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
if s.enhanced != "" {
|
||||
return s.enhanced, s.enhanceErr
|
||||
}
|
||||
return "enhanced prompt", s.enhanceErr
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
s.imageCalls++
|
||||
return s.imageStatus, nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||
s.videoCalls++
|
||||
return s.videoStatus, nil
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
imageStatus: &SoraImageTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/a.png"},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
|
||||
urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"https://example.com/a.png"}, urls)
|
||||
require.Equal(t, 1, client.imageCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
enhanced: "cinematic prompt",
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
|
||||
},
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "prompt", result.MediaType)
|
||||
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/v.mp4"},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, client.storyboard)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/v.mp4"},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 3, client.videoReq.VideoCount)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "prompt", result.MediaType)
|
||||
require.Equal(t, 0, client.videoCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/original.mp4"},
|
||||
GenerationID: "gen_1",
|
||||
},
|
||||
parseErr: errors.New("parse failed"),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
|
||||
require.Equal(t, 1, client.postCalls)
|
||||
require.Equal(t, 0, client.deleteCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/original.mp4"},
|
||||
GenerationID: "gen_1",
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
|
||||
require.Equal(t, 1, client.postCalls)
|
||||
require.Equal(t, 1, client.deleteCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "failed",
|
||||
ErrorMsg: "reject",
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
|
||||
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, status)
|
||||
require.Contains(t, err.Error(), "reject")
|
||||
require.Equal(t, 1, client.videoCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
SoraMediaSigningKey: "test-key",
|
||||
SoraMediaSignedURLTTLSeconds: 600,
|
||||
},
|
||||
}
|
||||
service := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||
|
||||
url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
|
||||
require.Contains(t, url, "/sora/media-signed")
|
||||
require.Contains(t, url, "expires=")
|
||||
require.Contains(t, url, "sig=")
|
||||
}
|
||||
|
||||
func TestNormalizeSoraMediaURLs_Empty(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
result := svc.normalizeSoraMediaURLs(nil)
|
||||
require.Empty(t, result)
|
||||
|
||||
result = svc.normalizeSoraMediaURLs([]string{})
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"}
|
||||
result := svc.normalizeSoraMediaURLs(urls)
|
||||
require.Equal(t, urls, result)
|
||||
}
|
||||
|
||||
func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||
urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"}
|
||||
result := svc.normalizeSoraMediaURLs(urls)
|
||||
require.Len(t, result, 2)
|
||||
require.Contains(t, result[0], "/sora/media")
|
||||
require.Contains(t, result[1], "/sora/media")
|
||||
}
|
||||
|
||||
func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"}
|
||||
result := svc.normalizeSoraMediaURLs(urls)
|
||||
require.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestBuildSoraContent_Image(t *testing.T) {
|
||||
content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"})
|
||||
require.Contains(t, content, "")
|
||||
require.Contains(t, content, "")
|
||||
}
|
||||
|
||||
func TestBuildSoraContent_Video(t *testing.T) {
|
||||
content := buildSoraContent("video", []string{"https://a.com/v.mp4"})
|
||||
require.Contains(t, content, "<video src='https://a.com/v.mp4'")
|
||||
}
|
||||
|
||||
func TestBuildSoraContent_VideoEmpty(t *testing.T) {
|
||||
content := buildSoraContent("video", nil)
|
||||
require.Empty(t, content)
|
||||
}
|
||||
|
||||
func TestBuildSoraContent_Prompt(t *testing.T) {
|
||||
content := buildSoraContent("prompt", nil)
|
||||
require.Empty(t, content)
|
||||
}
|
||||
|
||||
func TestSoraImageSizeFromModel(t *testing.T) {
|
||||
require.Equal(t, "360", soraImageSizeFromModel("gpt-image"))
|
||||
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-landscape"))
|
||||
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-portrait"))
|
||||
require.Equal(t, "540", soraImageSizeFromModel("something-landscape"))
|
||||
require.Equal(t, "360", soraImageSizeFromModel("unknown-model"))
|
||||
}
|
||||
|
||||
func TestFirstMediaURL(t *testing.T) {
|
||||
require.Equal(t, "", firstMediaURL(nil))
|
||||
require.Equal(t, "", firstMediaURL([]string{}))
|
||||
require.Equal(t, "a", firstMediaURL([]string{"a", "b"}))
|
||||
}
|
||||
|
||||
func TestSoraProErrorMessage(t *testing.T) {
|
||||
require.Contains(t, soraProErrorMessage("sora2pro-hd", ""), "Pro-HD")
|
||||
require.Contains(t, soraProErrorMessage("sora2pro", ""), "Pro")
|
||||
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: error\n")
|
||||
require.Contains(t, body, "data: [DONE]\n\n")
|
||||
|
||||
lines := strings.Split(body, "\n")
|
||||
require.GreaterOrEqual(t, len(lines), 2)
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "))
|
||||
|
||||
data := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
|
||||
errObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
sourceHeaders := http.Header{}
|
||||
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
|
||||
err := svc.handleSoraRequestError(
|
||||
context.Background(),
|
||||
&Account{ID: 1, Platform: PlatformSora},
|
||||
&SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "forbidden",
|
||||
Headers: sourceHeaders,
|
||||
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
|
||||
},
|
||||
"sora2-landscape-10s",
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr)
|
||||
require.NotNil(t, failoverErr.ResponseHeaders)
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
||||
|
||||
sourceHeaders.Set("cf-ray", "mutated-after-return")
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
||||
}
|
||||
|
||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(404))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(429))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(500))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(502))
|
||||
require.False(t, svc.shouldFailoverUpstreamError(200))
|
||||
require.False(t, svc.shouldFailoverUpstreamError(400))
|
||||
}
|
||||
|
||||
func TestWithSoraTimeout_NilService(t *testing.T) {
|
||||
var svc *SoraGatewayService
|
||||
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
|
||||
require.NotNil(t, ctx)
|
||||
require.Nil(t, cancel)
|
||||
}
|
||||
|
||||
func TestWithSoraTimeout_ZeroTimeout(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
|
||||
require.NotNil(t, ctx)
|
||||
require.Nil(t, cancel)
|
||||
}
|
||||
|
||||
func TestWithSoraTimeout_PositiveTimeout(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
SoraRequestTimeoutSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
|
||||
require.NotNil(t, ctx)
|
||||
require.NotNil(t, cancel)
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestPollInterval(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||
require.Equal(t, 5*time.Second, svc.pollInterval())
|
||||
|
||||
// 默认值
|
||||
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
require.True(t, svc2.pollInterval() > 0)
|
||||
}
|
||||
|
||||
func TestPollMaxAttempts(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
MaxPollAttempts: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||
require.Equal(t, 100, svc.pollMaxAttempts())
|
||||
|
||||
// 默认值
|
||||
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
require.True(t, svc2.pollMaxAttempts() > 0)
|
||||
}
|
||||
|
||||
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
|
||||
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
||||
encoded := "data:image/png;base64,aGVsbG8="
|
||||
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, data)
|
||||
require.Contains(t, filename, ".png")
|
||||
}
|
||||
|
||||
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
|
||||
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, data)
|
||||
}
|
||||
|
||||
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"watermark_free": float64(1),
|
||||
"watermark_fallback_on_failure": float64(0),
|
||||
}
|
||||
opts := parseSoraWatermarkOptions(body)
|
||||
require.True(t, opts.Enabled)
|
||||
require.False(t, opts.FallbackOnFailure)
|
||||
}
|
||||
|
||||
func TestParseSoraVideoCount(t *testing.T) {
|
||||
require.Equal(t, 1, parseSoraVideoCount(nil))
|
||||
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
|
||||
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
|
||||
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
|
||||
}
|
||||
@ -1,532 +0,0 @@
|
||||
//nolint:unused
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||
|
||||
const soraRewriteBufferLimit = 2048
|
||||
|
||||
type soraStreamingResult struct {
|
||||
mediaType string
|
||||
mediaURLs []string
|
||||
imageCount int
|
||||
imageSize string
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
if c != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
if s.rateLimitService == nil || account == nil || resp == nil {
|
||||
return
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
|
||||
upstreamMsg = msg
|
||||
}
|
||||
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
if c != nil {
|
||||
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
|
||||
c.JSON(resp.StatusCode, responsePayload)
|
||||
}
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
|
||||
if len(respBody) > 0 {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(respBody, &payload); err == nil {
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if overrideMessage != "" {
|
||||
errObj["message"] = overrideMessage
|
||||
}
|
||||
payload["error"] = errObj
|
||||
return payload
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "upstream_error",
|
||||
"message": overrideMessage,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
|
||||
if resp == nil {
|
||||
return nil, errors.New("empty response")
|
||||
}
|
||||
|
||||
if clientStream {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
contentBuilder := strings.Builder{}
|
||||
var firstTokenMs *int
|
||||
var upstreamError error
|
||||
rewriteBuffer := ""
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
sendLine := func(line string) error {
|
||||
if !clientStream {
|
||||
return nil
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if soraSSEDataRe.MatchString(line) {
|
||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "[DONE]" {
|
||||
if rewriteBuffer != "" {
|
||||
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if flushLine != "" {
|
||||
if flushContent != "" {
|
||||
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := sendLine(flushLine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rewriteBuffer = ""
|
||||
}
|
||||
if err := sendLine("data: [DONE]"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
||||
if errEvent != nil && upstreamError == nil {
|
||||
upstreamError = errEvent
|
||||
}
|
||||
if contentDelta != "" {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := sendLine(updatedLine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := sendLine(line); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
if clientStream {
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
}
|
||||
if clientStream {
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := contentBuilder.String()
|
||||
mediaType, mediaURLs := s.extractSoraMedia(content)
|
||||
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
|
||||
mediaType = "prompt"
|
||||
}
|
||||
imageSize := ""
|
||||
imageCount := 0
|
||||
if mediaType == "image" {
|
||||
imageSize = soraImageSizeFromModel(originalModel)
|
||||
imageCount = len(mediaURLs)
|
||||
}
|
||||
|
||||
if upstreamError != nil && !clientStream {
|
||||
if c != nil {
|
||||
c.JSON(http.StatusBadGateway, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "upstream_error",
|
||||
"message": upstreamError.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, upstreamError
|
||||
}
|
||||
|
||||
if !clientStream {
|
||||
response := buildSoraNonStreamResponse(content, originalModel)
|
||||
if len(mediaURLs) > 0 {
|
||||
response["media_url"] = mediaURLs[0]
|
||||
if len(mediaURLs) > 1 {
|
||||
response["media_urls"] = mediaURLs
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
return &soraStreamingResult{
|
||||
mediaType: mediaType,
|
||||
mediaURLs: mediaURLs,
|
||||
imageCount: imageCount,
|
||||
imageSize: imageSize,
|
||||
firstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
||||
if strings.TrimSpace(data) == "" {
|
||||
return "data: ", "", nil
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
return "data: " + data, "", nil
|
||||
}
|
||||
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
return "data: " + data, "", errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
|
||||
payload["model"] = originalModel
|
||||
}
|
||||
|
||||
contentDelta, updated := extractSoraContent(payload)
|
||||
if updated {
|
||||
var rewritten string
|
||||
if rewriteBuffer != nil {
|
||||
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
||||
} else {
|
||||
rewritten = s.rewriteSoraContent(contentDelta)
|
||||
}
|
||||
if rewritten != contentDelta {
|
||||
applySoraContent(payload, rewritten)
|
||||
contentDelta = rewritten
|
||||
}
|
||||
}
|
||||
|
||||
updatedData, err := jsonMarshalRaw(payload)
|
||||
if err != nil {
|
||||
return "data: " + data, contentDelta, nil
|
||||
}
|
||||
return "data: " + string(updatedData), contentDelta, nil
|
||||
}
|
||||
|
||||
func extractSoraContent(payload map[string]any) (string, bool) {
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
return "", false
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||
if content, ok := delta["content"].(string); ok {
|
||||
return content, true
|
||||
}
|
||||
}
|
||||
if message, ok := choice["message"].(map[string]any); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
return content, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func applySoraContent(payload map[string]any, content string) {
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
return
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||
delta["content"] = content
|
||||
choice["delta"] = delta
|
||||
return
|
||||
}
|
||||
if message, ok := choice["message"].(map[string]any); ok {
|
||||
message["content"] = content
|
||||
choice["message"] = message
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
||||
if buffer == nil {
|
||||
return s.rewriteSoraContent(contentDelta)
|
||||
}
|
||||
if contentDelta == "" && *buffer == "" {
|
||||
return ""
|
||||
}
|
||||
combined := *buffer + contentDelta
|
||||
rewritten := s.rewriteSoraContent(combined)
|
||||
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
||||
if bufferStart < 0 {
|
||||
*buffer = ""
|
||||
return rewritten
|
||||
}
|
||||
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
||||
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
||||
}
|
||||
output := rewritten[:bufferStart]
|
||||
*buffer = rewritten[bufferStart:]
|
||||
return output
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
||||
minIndex := -1
|
||||
start := 0
|
||||
for {
|
||||
idx := strings.Index(content[start:], "![")
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
idx += start
|
||||
if !hasSoraImageMatchAt(content, idx) {
|
||||
if minIndex == -1 || idx < minIndex {
|
||||
minIndex = idx
|
||||
}
|
||||
}
|
||||
start = idx + 2
|
||||
}
|
||||
lower := strings.ToLower(content)
|
||||
start = 0
|
||||
for {
|
||||
idx := strings.Index(lower[start:], "<video")
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
idx += start
|
||||
if !hasSoraVideoMatchAt(content, idx) {
|
||||
if minIndex == -1 || idx < minIndex {
|
||||
minIndex = idx
|
||||
}
|
||||
}
|
||||
start = idx + len("<video")
|
||||
}
|
||||
return minIndex
|
||||
}
|
||||
|
||||
func hasSoraImageMatchAt(content string, idx int) bool {
|
||||
if idx < 0 || idx >= len(content) {
|
||||
return false
|
||||
}
|
||||
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
||||
return loc != nil && loc[0] == 0
|
||||
}
|
||||
|
||||
func hasSoraVideoMatchAt(content string, idx int) bool {
|
||||
if idx < 0 || idx >= len(content) {
|
||||
return false
|
||||
}
|
||||
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
||||
return loc != nil && loc[0] == 0
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||
if content == "" {
|
||||
return content
|
||||
}
|
||||
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
|
||||
sub := soraImageMarkdownRe.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
rewritten := s.rewriteSoraURL(sub[1])
|
||||
if rewritten == sub[1] {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, sub[1], rewritten, 1)
|
||||
})
|
||||
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
|
||||
sub := soraVideoHTMLRe.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
rewritten := s.rewriteSoraURL(sub[1])
|
||||
if rewritten == sub[1] {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, sub[1], rewritten, 1)
|
||||
})
|
||||
return content
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
||||
if buffer == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
rewritten := s.rewriteSoraContent(buffer)
|
||||
payload := map[string]any{
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"delta": map[string]any{
|
||||
"content": rewritten,
|
||||
},
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
if originalModel != "" {
|
||||
payload["model"] = originalModel
|
||||
}
|
||||
updatedData, err := jsonMarshalRaw(payload)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return "data: " + string(updatedData), rewritten, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return raw
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
path := parsed.Path
|
||||
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
|
||||
return raw
|
||||
}
|
||||
return s.buildSoraMediaURL(path, parsed.RawQuery)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
|
||||
if content == "" {
|
||||
return "", nil
|
||||
}
|
||||
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
|
||||
return "video", []string{match[1]}
|
||||
}
|
||||
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
|
||||
if len(imageMatches) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
urls := make([]string, 0, len(imageMatches))
|
||||
for _, match := range imageMatches {
|
||||
if len(match) > 1 {
|
||||
urls = append(urls, match[1])
|
||||
}
|
||||
}
|
||||
return "image", urls
|
||||
}
|
||||
|
||||
func isSoraPromptEnhanceModel(model string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
|
||||
}
|
||||
@ -1,63 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SoraGeneration 代表一条 Sora 客户端生成记录。
|
||||
type SoraGeneration struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MediaType string `json:"media_type"` // video / image
|
||||
Status string `json:"status"` // pending / generating / completed / failed / cancelled
|
||||
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
|
||||
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
|
||||
FileSizeBytes int64 `json:"file_size_bytes"`
|
||||
StorageType string `json:"storage_type"` // s3 / local / upstream / none
|
||||
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
|
||||
UpstreamTaskID string `json:"upstream_task_id"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
}
|
||||
|
||||
// Sora 生成记录状态常量
|
||||
const (
|
||||
SoraGenStatusPending = "pending"
|
||||
SoraGenStatusGenerating = "generating"
|
||||
SoraGenStatusCompleted = "completed"
|
||||
SoraGenStatusFailed = "failed"
|
||||
SoraGenStatusCancelled = "cancelled"
|
||||
)
|
||||
|
||||
// Sora 存储类型常量
|
||||
const (
|
||||
SoraStorageTypeS3 = "s3"
|
||||
SoraStorageTypeLocal = "local"
|
||||
SoraStorageTypeUpstream = "upstream"
|
||||
SoraStorageTypeNone = "none"
|
||||
)
|
||||
|
||||
// SoraGenerationListParams 查询生成记录的参数。
|
||||
type SoraGenerationListParams struct {
|
||||
UserID int64
|
||||
Status string // 可选筛选
|
||||
StorageType string // 可选筛选
|
||||
MediaType string // 可选筛选
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// SoraGenerationRepository 生成记录持久化接口。
|
||||
type SoraGenerationRepository interface {
|
||||
Create(ctx context.Context, gen *SoraGeneration) error
|
||||
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
|
||||
Update(ctx context.Context, gen *SoraGeneration) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
|
||||
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
|
||||
}
|
||||
@ -1,332 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
|
||||
ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
|
||||
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
|
||||
ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
|
||||
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
|
||||
ErrSoraGenerationNotActive = errors.New("sora generation is not active")
|
||||
)
|
||||
|
||||
const soraGenerationActiveLimit = 3
|
||||
|
||||
type soraGenerationRepoAtomicCreator interface {
|
||||
CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
|
||||
}
|
||||
|
||||
type soraGenerationRepoConditionalUpdater interface {
|
||||
UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
|
||||
UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
|
||||
UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
|
||||
UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
|
||||
UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
|
||||
}
|
||||
|
||||
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
|
||||
type SoraGenerationService struct {
|
||||
genRepo SoraGenerationRepository
|
||||
s3Storage *SoraS3Storage
|
||||
quotaService *SoraQuotaService
|
||||
}
|
||||
|
||||
// NewSoraGenerationService 创建生成记录服务。
|
||||
func NewSoraGenerationService(
|
||||
genRepo SoraGenerationRepository,
|
||||
s3Storage *SoraS3Storage,
|
||||
quotaService *SoraQuotaService,
|
||||
) *SoraGenerationService {
|
||||
return &SoraGenerationService{
|
||||
genRepo: genRepo,
|
||||
s3Storage: s3Storage,
|
||||
quotaService: quotaService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePending 创建一条 pending 状态的生成记录。
|
||||
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
|
||||
gen := &SoraGeneration{
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
MediaType: mediaType,
|
||||
Status: SoraGenStatusPending,
|
||||
StorageType: SoraStorageTypeNone,
|
||||
}
|
||||
if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
|
||||
if err := atomicCreator.CreatePendingWithLimit(
|
||||
ctx,
|
||||
gen,
|
||||
[]string{SoraGenStatusPending, SoraGenStatusGenerating},
|
||||
soraGenerationActiveLimit,
|
||||
); err != nil {
|
||||
if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("create generation: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
if err := s.genRepo.Create(ctx, gen); err != nil {
|
||||
return nil, fmt.Errorf("create generation: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
// MarkGenerating 标记为生成中。
|
||||
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.Status = SoraGenStatusGenerating
|
||||
gen.UpstreamTaskID = upstreamTaskID
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// MarkCompleted 标记为已完成。
|
||||
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
|
||||
now := time.Now()
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.Status = SoraGenStatusCompleted
|
||||
gen.MediaURL = mediaURL
|
||||
gen.MediaURLs = mediaURLs
|
||||
gen.StorageType = storageType
|
||||
gen.S3ObjectKeys = s3Keys
|
||||
gen.FileSizeBytes = fileSizeBytes
|
||||
gen.CompletedAt = &now
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// MarkFailed 标记为失败。
|
||||
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
|
||||
now := time.Now()
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.Status = SoraGenStatusFailed
|
||||
gen.ErrorMessage = errMsg
|
||||
gen.CompletedAt = &now
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// MarkCancelled 标记为已取消。
|
||||
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationNotActive
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
||||
return ErrSoraGenerationNotActive
|
||||
}
|
||||
gen.Status = SoraGenStatusCancelled
|
||||
gen.CompletedAt = &now
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
|
||||
func (s *SoraGenerationService) UpdateStorageForCompleted(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
) error {
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusCompleted {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.MediaURL = mediaURL
|
||||
gen.MediaURLs = mediaURLs
|
||||
gen.StorageType = storageType
|
||||
gen.S3ObjectKeys = s3Keys
|
||||
gen.FileSizeBytes = fileSizeBytes
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// GetByID 获取记录详情(含权限校验)。
|
||||
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if gen.UserID != userID {
|
||||
return nil, fmt.Errorf("无权访问此生成记录")
|
||||
}
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
// List 查询生成记录列表(分页 + 筛选)。
|
||||
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
||||
if params.Page <= 0 {
|
||||
params.Page = 1
|
||||
}
|
||||
if params.PageSize <= 0 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
if params.PageSize > 100 {
|
||||
params.PageSize = 100
|
||||
}
|
||||
return s.genRepo.List(ctx, params)
|
||||
}
|
||||
|
||||
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
|
||||
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.UserID != userID {
|
||||
return fmt.Errorf("无权删除此生成记录")
|
||||
}
|
||||
|
||||
// 清理 S3 文件
|
||||
if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
|
||||
if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 释放配额(S3/本地均释放)
|
||||
if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
|
||||
if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.genRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
|
||||
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
|
||||
return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
|
||||
}
|
||||
|
||||
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
|
||||
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
|
||||
if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
|
||||
return nil
|
||||
}
|
||||
if len(gen.S3ObjectKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
urls := make([]string, len(gen.S3ObjectKeys))
|
||||
var wg sync.WaitGroup
|
||||
var firstErr error
|
||||
var errMu sync.Mutex
|
||||
|
||||
for idx, key := range gen.S3ObjectKeys {
|
||||
wg.Add(1)
|
||||
go func(i int, objectKey string) {
|
||||
defer wg.Done()
|
||||
url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
|
||||
if err != nil {
|
||||
errMu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
errMu.Unlock()
|
||||
return
|
||||
}
|
||||
urls[i] = url
|
||||
}(idx, key)
|
||||
}
|
||||
wg.Wait()
|
||||
if firstErr != nil {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
gen.MediaURL = urls[0]
|
||||
gen.MediaURLs = urls
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,881 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ==================== Stub: SoraGenerationRepository ====================
|
||||
|
||||
var _ SoraGenerationRepository = (*stubGenRepo)(nil)
|
||||
|
||||
type stubGenRepo struct {
|
||||
gens map[int64]*SoraGeneration
|
||||
nextID int64
|
||||
createErr error
|
||||
getErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
listErr error
|
||||
countErr error
|
||||
countValue int64
|
||||
}
|
||||
|
||||
func newStubGenRepo() *stubGenRepo {
|
||||
return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
|
||||
if r.createErr != nil {
|
||||
return r.createErr
|
||||
}
|
||||
gen.ID = r.nextID
|
||||
gen.CreatedAt = time.Now()
|
||||
r.nextID++
|
||||
r.gens[gen.ID] = gen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
if gen, ok := r.gens[id]; ok {
|
||||
return gen, nil
|
||||
}
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.gens[gen.ID] = gen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
|
||||
if r.deleteErr != nil {
|
||||
return r.deleteErr
|
||||
}
|
||||
delete(r.gens, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
||||
if r.listErr != nil {
|
||||
return nil, 0, r.listErr
|
||||
}
|
||||
var result []*SoraGeneration
|
||||
for _, gen := range r.gens {
|
||||
if gen.UserID != params.UserID {
|
||||
continue
|
||||
}
|
||||
if params.Status != "" && gen.Status != params.Status {
|
||||
continue
|
||||
}
|
||||
if params.StorageType != "" && gen.StorageType != params.StorageType {
|
||||
continue
|
||||
}
|
||||
if params.MediaType != "" && gen.MediaType != params.MediaType {
|
||||
continue
|
||||
}
|
||||
result = append(result, gen)
|
||||
}
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
|
||||
if r.countErr != nil {
|
||||
return 0, r.countErr
|
||||
}
|
||||
if r.countValue > 0 {
|
||||
return r.countValue, nil
|
||||
}
|
||||
var count int64
|
||||
statusSet := make(map[string]struct{})
|
||||
for _, s := range statuses {
|
||||
statusSet[s] = struct{}{}
|
||||
}
|
||||
for _, gen := range r.gens {
|
||||
if gen.UserID == userID {
|
||||
if _, ok := statusSet[gen.Status]; ok {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
|
||||
|
||||
var _ UserRepository = (*stubUserRepoForQuota)(nil)
|
||||
|
||||
type stubUserRepoForQuota struct {
|
||||
users map[int64]*User
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func newStubUserRepoForQuota() *stubUserRepoForQuota {
|
||||
return &stubUserRepoForQuota{users: make(map[int64]*User)}
|
||||
}
|
||||
|
||||
func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
|
||||
if u, ok := r.users[id]; ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.users[user.ID] = user
|
||||
return nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
|
||||
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
|
||||
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
|
||||
|
||||
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
|
||||
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
|
||||
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
|
||||
storage := &SoraS3Storage{}
|
||||
storage.cfg = &SoraS3Settings{
|
||||
Enabled: true,
|
||||
Bucket: "test-bucket",
|
||||
CDNURL: cdnURL,
|
||||
}
|
||||
// 需要 non-nil client 使 getClient 命中缓存
|
||||
storage.client = s3.New(s3.Options{})
|
||||
return storage
|
||||
}
|
||||
|
||||
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
|
||||
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
|
||||
func newS3StorageFailingDelete() *SoraS3Storage {
|
||||
return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
|
||||
}
|
||||
|
||||
// ==================== CreatePending ====================
|
||||
|
||||
func TestCreatePending_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), gen.ID)
|
||||
require.Equal(t, int64(1), gen.UserID)
|
||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
||||
require.Equal(t, "一只猫跳舞", gen.Prompt)
|
||||
require.Equal(t, "video", gen.MediaType)
|
||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
||||
require.Equal(t, SoraStorageTypeNone, gen.StorageType)
|
||||
require.Nil(t, gen.APIKeyID)
|
||||
}
|
||||
|
||||
func TestCreatePending_WithAPIKeyID(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
apiKeyID := int64(42)
|
||||
gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen.APIKeyID)
|
||||
require.Equal(t, int64(42), *gen.APIKeyID)
|
||||
}
|
||||
|
||||
func TestCreatePending_RepoError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.createErr = fmt.Errorf("db write error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
require.Contains(t, err.Error(), "create generation")
|
||||
}
|
||||
|
||||
// ==================== MarkGenerating ====================
|
||||
|
||||
func TestMarkGenerating_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
|
||||
require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
|
||||
}
|
||||
|
||||
func TestMarkGenerating_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 999, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkGenerating_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 1, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkCompleted ====================
|
||||
|
||||
func TestMarkCompleted_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 1,
|
||||
"https://cdn.example.com/video.mp4",
|
||||
[]string{"https://cdn.example.com/video.mp4"},
|
||||
SoraStorageTypeS3,
|
||||
[]string{"sora/1/2024/01/01/uuid.mp4"},
|
||||
1048576,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
gen := repo.gens[1]
|
||||
require.Equal(t, SoraGenStatusCompleted, gen.Status)
|
||||
require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
|
||||
require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
|
||||
require.Equal(t, SoraStorageTypeS3, gen.StorageType)
|
||||
require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
|
||||
require.Equal(t, int64(1048576), gen.FileSizeBytes)
|
||||
require.NotNil(t, gen.CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkCompleted_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCompleted_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkFailed ====================
|
||||
|
||||
func TestMarkFailed_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
|
||||
require.NoError(t, err)
|
||||
gen := repo.gens[1]
|
||||
require.Equal(t, SoraGenStatusFailed, gen.Status)
|
||||
require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
|
||||
require.NotNil(t, gen.CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkFailed_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 999, "error")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkFailed_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 1, "err")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkCancelled ====================
|
||||
|
||||
func TestMarkCancelled_Pending(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
||||
require.NotNil(t, repo.gens[1].CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Generating(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Completed(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrSoraGenerationNotActive)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Failed(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 999)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== GetByID ====================
|
||||
|
||||
func TestGetByID_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), gen.ID)
|
||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
||||
}
|
||||
|
||||
func TestGetByID_WrongUser(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
require.Contains(t, err.Error(), "无权访问")
|
||||
}
|
||||
|
||||
func TestGetByID_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
}
|
||||
|
||||
// ==================== List ====================
|
||||
|
||||
func TestList_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
|
||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
|
||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gens, 2) // 只有 userID=1 的
|
||||
require.Equal(t, int64(2), total)
|
||||
}
|
||||
|
||||
func TestList_DefaultPagination(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestList_MaxPageSize(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// pageSize > 100 → 应限制为 100
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestList_Error(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.listErr = fmt.Errorf("db error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== Delete ====================
|
||||
|
||||
func TestDelete_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDelete_WrongUser(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权删除")
|
||||
}
|
||||
|
||||
func TestDelete_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err) // s3Storage 为 nil,跳过清理
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err) // quotaService 为 nil,跳过释放
|
||||
}
|
||||
|
||||
func TestDelete_NonS3NoCleanup(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDelete_DeleteRepoError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
|
||||
repo.deleteErr = fmt.Errorf("delete failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== CountActiveByUser ====================
|
||||
|
||||
func TestCountActiveByUser_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
count, err := svc.CountActiveByUser(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestCountActiveByUser_NoActive(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
count, err := svc.CountActiveByUser(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestCountActiveByUser_Error(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.countErr = fmt.Errorf("db error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
_, err := svc.CountActiveByUser(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== ResolveMediaURLs ====================
|
||||
|
||||
func TestResolveMediaURLs_NilGen(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_NonS3(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
||||
require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_Local(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
||||
require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变
|
||||
}
|
||||
|
||||
// ==================== 状态流转完整测试 ====================
|
||||
|
||||
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// 1. 创建 pending
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
||||
|
||||
// 2. 标记 generating
|
||||
err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)
|
||||
|
||||
// 3. 标记 completed
|
||||
err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
|
||||
}
|
||||
|
||||
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
|
||||
|
||||
err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
|
||||
require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
|
||||
}
|
||||
|
||||
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
err := svc.MarkCancelled(context.Background(), gen.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
|
||||
}
|
||||
|
||||
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
|
||||
err := svc.MarkCancelled(context.Background(), gen.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
|
||||
}
|
||||
|
||||
// ==================== 权限隔离测试 ====================
|
||||
|
||||
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
|
||||
// 用户 2 尝试访问用户 1 的记录
|
||||
_, err := svc.GetByID(context.Background(), gen.ID, 2)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权访问")
|
||||
}
|
||||
|
||||
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
|
||||
err := svc.Delete(context.Background(), gen.ID, 2)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权删除")
|
||||
}
|
||||
|
||||
// ==================== Delete: S3 清理 + 配额释放路径 ====================
|
||||
|
||||
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
|
||||
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
|
||||
// 验证 Delete 仍然成功(S3 错误只是记录日志)
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
|
||||
}
|
||||
s3Storage := newS3StorageFailingDelete()
|
||||
svc := NewSoraGenerationService(repo, s3Storage, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err) // S3 清理失败不影响删除
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
|
||||
// 有配额服务时,删除 S3 类型记录会释放配额
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
FileSizeBytes: 1048576, // 1MB
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
// 配额应被释放: 2MB - 1MB = 1MB
|
||||
require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
|
||||
// S3 清理 + 配额释放同时触发
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"key1"},
|
||||
FileSizeBytes: 512,
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
s3Storage := newS3StorageFailingDelete()
|
||||
|
||||
svc := NewSoraGenerationService(repo, s3Storage, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
|
||||
// 本地存储同样需要释放配额
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeLocal,
|
||||
FileSizeBytes: 1024,
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
|
||||
// FileSizeBytes=0 跳过配额释放
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
FileSizeBytes: 0,
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
|
||||
|
||||
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
|
||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
|
||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{
|
||||
"sora/1/2024/01/01/img1.png",
|
||||
"sora/1/2024/01/01/img2.png",
|
||||
"sora/1/2024/01/01/img3.png",
|
||||
},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.NoError(t, err)
|
||||
// 主 URL 更新为第一个 key 的 CDN URL
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
|
||||
// 多图 URLs 全部更新
|
||||
require.Len(t, gen.MediaURLs, 3)
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
|
||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original", gen.MediaURL) // 不变
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
|
||||
// 使用无 settingService 的 S3 Storage,getClient 会失败
|
||||
s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.Error(t, err) // GetAccessURL 失败应传播错误
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
|
||||
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
|
||||
s3Storage := newS3StorageFailingDelete()
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{
|
||||
"sora/1/2024/01/01/img1.png",
|
||||
"sora/1/2024/01/01/img2.png",
|
||||
},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
|
||||
}
|
||||
@ -1,120 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
|
||||
// SoraMediaCleanupService 定期清理本地媒体文件
|
||||
type SoraMediaCleanupService struct {
|
||||
storage *SoraMediaStorage
|
||||
cfg *config.Config
|
||||
|
||||
cron *cron.Cron
|
||||
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||
return &SoraMediaCleanupService{
|
||||
storage: storage,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraMediaCleanupService) Start() {
|
||||
if s == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Sora.Storage.Cleanup.Enabled {
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.storage == nil || !s.storage.Enabled() {
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (storage disabled)")
|
||||
return
|
||||
}
|
||||
|
||||
s.startOnce.Do(func() {
|
||||
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
|
||||
if schedule == "" {
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (empty schedule)")
|
||||
return
|
||||
}
|
||||
loc := time.Local
|
||||
if strings.TrimSpace(s.cfg.Timezone) != "" {
|
||||
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
|
||||
loc = parsed
|
||||
}
|
||||
}
|
||||
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
|
||||
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
return
|
||||
}
|
||||
s.cron = c
|
||||
s.cron.Start()
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SoraMediaCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.cron != nil {
|
||||
ctx := s.cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(3 * time.Second):
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cron stop timed out")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SoraMediaCleanupService) runCleanup() {
|
||||
if s.cfg == nil || s.storage == nil {
|
||||
return
|
||||
}
|
||||
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
|
||||
if retention <= 0 {
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] skipped (retention_days=%d)", retention)
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -retention)
|
||||
deleted := 0
|
||||
|
||||
roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()}
|
||||
for _, root := range roots {
|
||||
if root == "" {
|
||||
continue
|
||||
}
|
||||
_ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if info.ModTime().Before(cutoff) {
|
||||
if rmErr := os.Remove(p); rmErr == nil {
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cleanup finished, deleted=%d", deleted)
|
||||
}
|
||||
@ -1,207 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSoraMediaCleanupService_RunCleanup_NilCfg(t *testing.T) {
|
||||
storage := &SoraMediaStorage{}
|
||||
svc := &SoraMediaCleanupService{storage: storage, cfg: nil}
|
||||
// 不应 panic
|
||||
svc.runCleanup()
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_RunCleanup_NilStorage(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
svc := &SoraMediaCleanupService{storage: nil, cfg: cfg}
|
||||
// 不应 panic
|
||||
svc.runCleanup()
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_RunCleanup_ZeroRetention(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
RetentionDays: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
// retention=0 应跳过清理
|
||||
svc.runCleanup()
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_NilCfg(t *testing.T) {
|
||||
svc := NewSoraMediaCleanupService(nil, nil)
|
||||
svc.Start() // cfg == nil 时应直接返回
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_StorageDisabled(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraMediaCleanupService(nil, cfg)
|
||||
svc.Start() // storage == nil 时应直接返回
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_WithTimezone(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Timezone: "Asia/Shanghai",
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
Schedule: "0 3 * * *",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
svc.Start()
|
||||
t.Cleanup(svc.Stop)
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_Disabled(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraMediaCleanupService(nil, cfg)
|
||||
svc.Start() // 不应 panic,也不应启动 cron
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_NilSelf(t *testing.T) {
|
||||
var svc *SoraMediaCleanupService
|
||||
svc.Start() // 不应 panic
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_EmptySchedule(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
Schedule: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
svc.Start() // 空 schedule 不应启动
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_InvalidSchedule(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
Schedule: "invalid-cron",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
svc.Start() // 无效 schedule 不应 panic
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Start_ValidSchedule(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
Schedule: "0 3 * * *",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
svc.Start()
|
||||
t.Cleanup(svc.Stop)
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Stop_NilSelf(t *testing.T) {
|
||||
var svc *SoraMediaCleanupService
|
||||
svc.Stop() // 不应 panic
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_Stop_WithoutStart(t *testing.T) {
|
||||
svc := NewSoraMediaCleanupService(nil, &config.Config{})
|
||||
svc.Stop() // cron 未启动时 Stop 不应 panic
|
||||
}
|
||||
|
||||
func TestSoraMediaCleanupService_RunCleanup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
Cleanup: config.SoraStorageCleanupConfig{
|
||||
Enabled: true,
|
||||
RetentionDays: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
require.NoError(t, storage.EnsureLocalDirs())
|
||||
|
||||
oldImage := filepath.Join(storage.ImageRoot(), "old.png")
|
||||
newVideo := filepath.Join(storage.VideoRoot(), "new.mp4")
|
||||
require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644))
|
||||
require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644))
|
||||
|
||||
oldTime := time.Now().Add(-48 * time.Hour)
|
||||
require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime))
|
||||
|
||||
cleanup := NewSoraMediaCleanupService(storage, cfg)
|
||||
cleanup.runCleanup()
|
||||
|
||||
require.NoFileExists(t, oldImage)
|
||||
require.FileExists(t, newVideo)
|
||||
}
|
||||
@ -1,48 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SignSoraMediaURL 生成 Sora 媒体临时签名
|
||||
func SignSoraMediaURL(path string, query string, expires int64, key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil {
|
||||
return ""
|
||||
}
|
||||
if _, err := mac.Write([]byte("|")); err != nil {
|
||||
return ""
|
||||
}
|
||||
if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// VerifySoraMediaURL 校验 Sora 媒体签名
|
||||
func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool {
|
||||
signature = strings.TrimSpace(signature)
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
expected := SignSoraMediaURL(path, query, expires, key)
|
||||
if expected == "" {
|
||||
return false
|
||||
}
|
||||
return hmac.Equal([]byte(signature), []byte(expected))
|
||||
}
|
||||
|
||||
func buildSoraMediaSignPayload(path string, query string) string {
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return path
|
||||
}
|
||||
return path + "?" + query
|
||||
}
|
||||
@ -1,34 +0,0 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSoraMediaSignVerify(t *testing.T) {
|
||||
key := "test-key"
|
||||
path := "/tmp/abc.png"
|
||||
query := "a=1&b=2"
|
||||
expires := int64(1700000000)
|
||||
|
||||
signature := SignSoraMediaURL(path, query, expires, key)
|
||||
if signature == "" {
|
||||
t.Fatal("签名为空")
|
||||
}
|
||||
if !VerifySoraMediaURL(path, query, expires, signature, key) {
|
||||
t.Fatal("签名校验失败")
|
||||
}
|
||||
if VerifySoraMediaURL(path, "a=1", expires, signature, key) {
|
||||
t.Fatal("签名参数不同仍然通过")
|
||||
}
|
||||
if VerifySoraMediaURL(path, query, expires+1, signature, key) {
|
||||
t.Fatal("签名过期校验未失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraMediaSignWithEmptyKey(t *testing.T) {
|
||||
signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "")
|
||||
if signature != "" {
|
||||
t.Fatalf("空密钥不应生成签名")
|
||||
}
|
||||
if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") {
|
||||
t.Fatalf("空密钥不应通过校验")
|
||||
}
|
||||
}
|
||||
@ -1,381 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
soraStorageDefaultRoot = "/app/data/sora"
|
||||
)
|
||||
|
||||
// SoraMediaStorage 负责下载并落地 Sora 媒体
|
||||
type SoraMediaStorage struct {
|
||||
cfg *config.Config
|
||||
root string
|
||||
imageRoot string
|
||||
videoRoot string
|
||||
downloadTimeout time.Duration
|
||||
maxDownloadBytes int64
|
||||
fallbackToUpstream bool
|
||||
debug bool
|
||||
sem chan struct{}
|
||||
ready bool
|
||||
}
|
||||
|
||||
func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
storage := &SoraMediaStorage{cfg: cfg}
|
||||
storage.refreshConfig()
|
||||
if storage.Enabled() {
|
||||
if err := storage.EnsureLocalDirs(); err != nil {
|
||||
log.Printf("[SoraStorage] 初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
return storage
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) Enabled() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local"
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) Root() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.root
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) ImageRoot() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.imageRoot
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) VideoRoot() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.videoRoot
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) refreshConfig() {
|
||||
if s == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath)
|
||||
if root == "" {
|
||||
root = soraStorageDefaultRoot
|
||||
}
|
||||
root = filepath.Clean(root)
|
||||
if !filepath.IsAbs(root) {
|
||||
if absRoot, err := filepath.Abs(root); err == nil {
|
||||
root = absRoot
|
||||
}
|
||||
}
|
||||
s.root = root
|
||||
s.imageRoot = filepath.Join(root, "image")
|
||||
s.videoRoot = filepath.Join(root, "video")
|
||||
|
||||
maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads
|
||||
if maxConcurrent <= 0 {
|
||||
maxConcurrent = 4
|
||||
}
|
||||
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 120
|
||||
}
|
||||
s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second
|
||||
|
||||
maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 200 << 20
|
||||
}
|
||||
s.maxDownloadBytes = maxBytes
|
||||
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
|
||||
s.debug = s.cfg.Sora.Storage.Debug
|
||||
s.sem = make(chan struct{}, maxConcurrent)
|
||||
}
|
||||
|
||||
// EnsureLocalDirs 创建并校验本地目录
|
||||
func (s *SoraMediaStorage) EnsureLocalDirs() error {
|
||||
if s == nil || !s.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if err := os.MkdirAll(s.imageRoot, 0o755); err != nil {
|
||||
return fmt.Errorf("create image dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(s.videoRoot, 0o755); err != nil {
|
||||
return fmt.Errorf("create video dir: %w", err)
|
||||
}
|
||||
s.ready = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
|
||||
func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) {
|
||||
if len(urls) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if s == nil || !s.Enabled() {
|
||||
return urls, nil
|
||||
}
|
||||
if !s.ready {
|
||||
if err := s.EnsureLocalDirs(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
results := make([]string, 0, len(urls))
|
||||
for _, raw := range urls {
|
||||
relative, err := s.downloadAndStore(ctx, mediaType, raw)
|
||||
if err != nil {
|
||||
if s.fallbackToUpstream {
|
||||
results = append(results, raw)
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, relative)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
|
||||
func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) {
|
||||
if s == nil || len(paths) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var total int64
|
||||
for _, p := range paths {
|
||||
localPath, err := s.resolveLocalPath(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
if info.Mode().IsRegular() {
|
||||
total += info.Size()
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
|
||||
func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error {
|
||||
if s == nil || len(paths) == 0 {
|
||||
return nil
|
||||
}
|
||||
var lastErr error
|
||||
for _, p := range paths {
|
||||
localPath, err := s.resolveLocalPath(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) {
|
||||
if s == nil || strings.TrimSpace(relativePath) == "" {
|
||||
return "", errors.New("empty path")
|
||||
}
|
||||
cleaned := path.Clean(relativePath)
|
||||
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
||||
return "", errors.New("not a local media path")
|
||||
}
|
||||
if strings.TrimSpace(s.root) == "" {
|
||||
return "", errors.New("storage root not configured")
|
||||
}
|
||||
relative := strings.TrimPrefix(cleaned, "/")
|
||||
return filepath.Join(s.root, filepath.FromSlash(relative)), nil
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
|
||||
if strings.TrimSpace(rawURL) == "" {
|
||||
return "", errors.New("empty url")
|
||||
}
|
||||
root := s.imageRoot
|
||||
if mediaType == "video" {
|
||||
root = s.videoRoot
|
||||
}
|
||||
if root == "" {
|
||||
return "", errors.New("storage root not configured")
|
||||
}
|
||||
|
||||
retries := 3
|
||||
for attempt := 1; attempt <= retries; attempt++ {
|
||||
release, err := s.acquire(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
relative, err := s.downloadOnce(ctx, root, mediaType, rawURL)
|
||||
release()
|
||||
if err == nil {
|
||||
return relative, nil
|
||||
}
|
||||
if s.debug {
|
||||
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err)
|
||||
}
|
||||
if attempt < retries {
|
||||
time.Sleep(time.Duration(attempt*attempt) * time.Second)
|
||||
continue
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return "", errors.New("download retries exhausted")
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
client := &http.Client{Timeout: s.downloadTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
ext := normalizeSoraFileExt(fileExtFromURL(rawURL))
|
||||
if ext == "" {
|
||||
ext = normalizeSoraFileExt(fileExtFromContentType(resp.Header.Get("Content-Type")))
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes {
|
||||
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
|
||||
}
|
||||
|
||||
storageRoot, err := os.OpenRoot(root)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = storageRoot.Close() }()
|
||||
|
||||
datePath := time.Now().Format("2006/01/02")
|
||||
datePathFS := filepath.FromSlash(datePath)
|
||||
if err := storageRoot.MkdirAll(datePathFS, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
filename := uuid.NewString() + ext
|
||||
filePath := filepath.Join(datePathFS, filename)
|
||||
out, err := storageRoot.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
removePartialDownload(storageRoot, filePath)
|
||||
return "", err
|
||||
}
|
||||
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
|
||||
removePartialDownload(storageRoot, filePath)
|
||||
return "", fmt.Errorf("download size exceeds limit: %d", written)
|
||||
}
|
||||
|
||||
relative := path.Join("/", mediaType, datePath, filename)
|
||||
if s.debug {
|
||||
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative)
|
||||
}
|
||||
return relative, nil
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) {
|
||||
if s.sem == nil {
|
||||
return func() {}, nil
|
||||
}
|
||||
select {
|
||||
case s.sem <- struct{}{}:
|
||||
return func() { <-s.sem }, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func fileExtFromURL(raw string) string {
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
ext := path.Ext(parsed.Path)
|
||||
return strings.ToLower(ext)
|
||||
}
|
||||
|
||||
func fileExtFromContentType(ct string) string {
|
||||
if ct == "" {
|
||||
return ""
|
||||
}
|
||||
if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 {
|
||||
return strings.ToLower(exts[0])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeSoraFileExt(ext string) string {
|
||||
ext = strings.ToLower(strings.TrimSpace(ext))
|
||||
switch ext {
|
||||
case ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg", ".tif", ".tiff", ".heic",
|
||||
".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
|
||||
return ext
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func removePartialDownload(root *os.Root, filePath string) {
|
||||
if root == nil || strings.TrimSpace(filePath) == "" {
|
||||
return
|
||||
}
|
||||
_ = root.Remove(filePath)
|
||||
}
|
||||
|
||||
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
|
||||
func sanitizeMediaLogURL(rawURL string) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
if len(rawURL) > 80 {
|
||||
return rawURL[:80] + "..."
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
safe := parsed.Scheme + "://" + parsed.Host + parsed.Path
|
||||
if len(safe) > 120 {
|
||||
return safe[:120] + "..."
|
||||
}
|
||||
return safe
|
||||
}
|
||||
@ -1,119 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSoraMediaStorage_StoreFromURLs(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("data"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
MaxConcurrentDownloads: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, urls, 1)
|
||||
require.True(t, strings.HasPrefix(urls[0], "/image/"))
|
||||
require.True(t, strings.HasSuffix(urls[0], ".png"))
|
||||
|
||||
localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/")))
|
||||
require.FileExists(t, localPath)
|
||||
}
|
||||
|
||||
func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
FallbackToUpstream: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
url := server.URL + "/broken.png"
|
||||
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{url}, urls)
|
||||
}
|
||||
|
||||
func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("too-large"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
MaxDownloadBytes: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNormalizeSoraFileExt(t *testing.T) {
|
||||
require.Equal(t, ".png", normalizeSoraFileExt(".PNG"))
|
||||
require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4"))
|
||||
require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd"))
|
||||
require.Equal(t, "", normalizeSoraFileExt(".php"))
|
||||
}
|
||||
|
||||
func TestRemovePartialDownload(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
root, err := os.OpenRoot(tmpDir)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = root.Close() }()
|
||||
|
||||
filePath := "partial.bin"
|
||||
f, err := root.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
|
||||
require.NoError(t, err)
|
||||
_, _ = f.WriteString("partial")
|
||||
_ = f.Close()
|
||||
|
||||
removePartialDownload(root, filePath)
|
||||
_, err = root.Stat(filePath)
|
||||
require.Error(t, err)
|
||||
require.True(t, os.IsNotExist(err))
|
||||
}
|
||||
@ -1,488 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// SoraModelConfig Sora 模型配置
|
||||
type SoraModelConfig struct {
|
||||
Type string
|
||||
Width int
|
||||
Height int
|
||||
Orientation string
|
||||
Frames int
|
||||
Model string
|
||||
Size string
|
||||
RequirePro bool
|
||||
// Prompt-enhance 专用参数
|
||||
ExpansionLevel string
|
||||
DurationS int
|
||||
}
|
||||
|
||||
var soraModelConfigs = map[string]SoraModelConfig{
|
||||
"gpt-image": {
|
||||
Type: "image",
|
||||
Width: 360,
|
||||
Height: 360,
|
||||
},
|
||||
"gpt-image-landscape": {
|
||||
Type: "image",
|
||||
Width: 540,
|
||||
Height: 360,
|
||||
},
|
||||
"gpt-image-portrait": {
|
||||
Type: "image",
|
||||
Width: 360,
|
||||
Height: 540,
|
||||
},
|
||||
"sora2-landscape-10s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 300,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
},
|
||||
"sora2-portrait-10s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 300,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
},
|
||||
"sora2-landscape-15s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 450,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
},
|
||||
"sora2-portrait-15s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 450,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
},
|
||||
"sora2-landscape-25s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 750,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2-portrait-25s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 750,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-landscape-10s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-portrait-10s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-landscape-15s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-portrait-15s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-landscape-25s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 750,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-portrait-25s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 750,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-landscape-10s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-portrait-10s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-landscape-15s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
Frames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-portrait-15s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
Frames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"prompt-enhance-short-10s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-short-15s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-short-20s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-medium-10s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-medium-15s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-medium-20s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-long-10s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-long-15s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-long-20s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 20,
|
||||
},
|
||||
}
|
||||
|
||||
var soraModelIDs = []string{
|
||||
"gpt-image",
|
||||
"gpt-image-landscape",
|
||||
"gpt-image-portrait",
|
||||
"sora2-landscape-10s",
|
||||
"sora2-portrait-10s",
|
||||
"sora2-landscape-15s",
|
||||
"sora2-portrait-15s",
|
||||
"sora2-landscape-25s",
|
||||
"sora2-portrait-25s",
|
||||
"sora2pro-landscape-10s",
|
||||
"sora2pro-portrait-10s",
|
||||
"sora2pro-landscape-15s",
|
||||
"sora2pro-portrait-15s",
|
||||
"sora2pro-landscape-25s",
|
||||
"sora2pro-portrait-25s",
|
||||
"sora2pro-hd-landscape-10s",
|
||||
"sora2pro-hd-portrait-10s",
|
||||
"sora2pro-hd-landscape-15s",
|
||||
"sora2pro-hd-portrait-15s",
|
||||
"prompt-enhance-short-10s",
|
||||
"prompt-enhance-short-15s",
|
||||
"prompt-enhance-short-20s",
|
||||
"prompt-enhance-medium-10s",
|
||||
"prompt-enhance-medium-15s",
|
||||
"prompt-enhance-medium-20s",
|
||||
"prompt-enhance-long-10s",
|
||||
"prompt-enhance-long-15s",
|
||||
"prompt-enhance-long-20s",
|
||||
}
|
||||
|
||||
// GetSoraModelConfig 返回 Sora 模型配置
|
||||
func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
|
||||
key := strings.ToLower(strings.TrimSpace(model))
|
||||
cfg, ok := soraModelConfigs[key]
|
||||
return cfg, ok
|
||||
}
|
||||
|
||||
// SoraModelFamily 模型家族(前端 Sora 客户端使用)
|
||||
type SoraModelFamily struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Orientations []string `json:"orientations"`
|
||||
Durations []int `json:"durations,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`)
|
||||
imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`)
|
||||
|
||||
soraFamilyNames = map[string]string{
|
||||
"sora2": "Sora 2",
|
||||
"sora2pro": "Sora 2 Pro",
|
||||
"sora2pro-hd": "Sora 2 Pro HD",
|
||||
"gpt-image": "GPT Image",
|
||||
}
|
||||
)
|
||||
|
||||
// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长
|
||||
func BuildSoraModelFamilies() []SoraModelFamily {
|
||||
type familyData struct {
|
||||
modelType string
|
||||
orientations map[string]bool
|
||||
durations map[int]bool
|
||||
}
|
||||
families := make(map[string]*familyData)
|
||||
|
||||
for id, cfg := range soraModelConfigs {
|
||||
if cfg.Type == "prompt_enhance" {
|
||||
continue
|
||||
}
|
||||
var famID, orientation string
|
||||
var duration int
|
||||
|
||||
switch cfg.Type {
|
||||
case "video":
|
||||
if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
|
||||
famID = id[:len(id)-len(m[0])]
|
||||
orientation = m[1]
|
||||
duration, _ = strconv.Atoi(m[2])
|
||||
}
|
||||
case "image":
|
||||
if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
|
||||
famID = id[:len(id)-len(m[0])]
|
||||
orientation = m[1]
|
||||
} else {
|
||||
famID = id
|
||||
orientation = "square"
|
||||
}
|
||||
}
|
||||
if famID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fd, ok := families[famID]
|
||||
if !ok {
|
||||
fd = &familyData{
|
||||
modelType: cfg.Type,
|
||||
orientations: make(map[string]bool),
|
||||
durations: make(map[int]bool),
|
||||
}
|
||||
families[famID] = fd
|
||||
}
|
||||
if orientation != "" {
|
||||
fd.orientations[orientation] = true
|
||||
}
|
||||
if duration > 0 {
|
||||
fd.durations[duration] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 排序:视频在前、图像在后,同类按名称排序
|
||||
famIDs := make([]string, 0, len(families))
|
||||
for id := range families {
|
||||
famIDs = append(famIDs, id)
|
||||
}
|
||||
sort.Slice(famIDs, func(i, j int) bool {
|
||||
fi, fj := families[famIDs[i]], families[famIDs[j]]
|
||||
if fi.modelType != fj.modelType {
|
||||
return fi.modelType == "video"
|
||||
}
|
||||
return famIDs[i] < famIDs[j]
|
||||
})
|
||||
|
||||
result := make([]SoraModelFamily, 0, len(famIDs))
|
||||
for _, famID := range famIDs {
|
||||
fd := families[famID]
|
||||
fam := SoraModelFamily{
|
||||
ID: famID,
|
||||
Name: soraFamilyNames[famID],
|
||||
Type: fd.modelType,
|
||||
}
|
||||
if fam.Name == "" {
|
||||
fam.Name = famID
|
||||
}
|
||||
for o := range fd.orientations {
|
||||
fam.Orientations = append(fam.Orientations, o)
|
||||
}
|
||||
sort.Strings(fam.Orientations)
|
||||
for d := range fd.durations {
|
||||
fam.Durations = append(fam.Durations, d)
|
||||
}
|
||||
sort.Ints(fam.Durations)
|
||||
result = append(result, fam)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。
|
||||
// 通过命名约定自动识别视频/图像模型并分组。
|
||||
func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily {
|
||||
type familyData struct {
|
||||
modelType string
|
||||
orientations map[string]bool
|
||||
durations map[int]bool
|
||||
}
|
||||
families := make(map[string]*familyData)
|
||||
|
||||
for _, id := range modelIDs {
|
||||
id = strings.ToLower(strings.TrimSpace(id))
|
||||
if id == "" || strings.HasPrefix(id, "prompt-enhance") {
|
||||
continue
|
||||
}
|
||||
|
||||
var famID, orientation, modelType string
|
||||
var duration int
|
||||
|
||||
if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
|
||||
// 视频模型: {family}-{orientation}-{duration}s
|
||||
famID = id[:len(id)-len(m[0])]
|
||||
orientation = m[1]
|
||||
duration, _ = strconv.Atoi(m[2])
|
||||
modelType = "video"
|
||||
} else if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
|
||||
// 图像模型(带方向): {family}-{orientation}
|
||||
famID = id[:len(id)-len(m[0])]
|
||||
orientation = m[1]
|
||||
modelType = "image"
|
||||
} else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" {
|
||||
// 已知的无后缀图像模型(如 gpt-image)
|
||||
famID = id
|
||||
orientation = "square"
|
||||
modelType = "image"
|
||||
} else if strings.Contains(id, "image") {
|
||||
// 未知但名称包含 image 的模型,推断为图像模型
|
||||
famID = id
|
||||
orientation = "square"
|
||||
modelType = "image"
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
if famID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fd, ok := families[famID]
|
||||
if !ok {
|
||||
fd = &familyData{
|
||||
modelType: modelType,
|
||||
orientations: make(map[string]bool),
|
||||
durations: make(map[int]bool),
|
||||
}
|
||||
families[famID] = fd
|
||||
}
|
||||
if orientation != "" {
|
||||
fd.orientations[orientation] = true
|
||||
}
|
||||
if duration > 0 {
|
||||
fd.durations[duration] = true
|
||||
}
|
||||
}
|
||||
|
||||
famIDs := make([]string, 0, len(families))
|
||||
for id := range families {
|
||||
famIDs = append(famIDs, id)
|
||||
}
|
||||
sort.Slice(famIDs, func(i, j int) bool {
|
||||
fi, fj := families[famIDs[i]], families[famIDs[j]]
|
||||
if fi.modelType != fj.modelType {
|
||||
return fi.modelType == "video"
|
||||
}
|
||||
return famIDs[i] < famIDs[j]
|
||||
})
|
||||
|
||||
result := make([]SoraModelFamily, 0, len(famIDs))
|
||||
for _, famID := range famIDs {
|
||||
fd := families[famID]
|
||||
fam := SoraModelFamily{
|
||||
ID: famID,
|
||||
Name: soraFamilyNames[famID],
|
||||
Type: fd.modelType,
|
||||
}
|
||||
if fam.Name == "" {
|
||||
fam.Name = famID
|
||||
}
|
||||
for o := range fd.orientations {
|
||||
fam.Orientations = append(fam.Orientations, o)
|
||||
}
|
||||
sort.Strings(fam.Orientations)
|
||||
for d := range fd.durations {
|
||||
fam.Durations = append(fam.Durations, d)
|
||||
}
|
||||
sort.Ints(fam.Durations)
|
||||
result = append(result, fam)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// DefaultSoraModels returns the default Sora model list.
|
||||
func DefaultSoraModels(cfg *config.Config) []openai.Model {
|
||||
models := make([]openai.Model, 0, len(soraModelIDs))
|
||||
for _, id := range soraModelIDs {
|
||||
models = append(models, openai.Model{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
OwnedBy: "openai",
|
||||
Type: "model",
|
||||
DisplayName: id,
|
||||
})
|
||||
}
|
||||
if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance {
|
||||
filtered := models[:0]
|
||||
for _, model := range models {
|
||||
if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, model)
|
||||
}
|
||||
models = filtered
|
||||
}
|
||||
return models
|
||||
}
|
||||
@ -1,257 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// SoraQuotaService 管理 Sora 用户存储配额。
|
||||
// 配额优先级:用户级 → 分组级 → 系统默认值。
|
||||
type SoraQuotaService struct {
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
settingService *SettingService
|
||||
}
|
||||
|
||||
// NewSoraQuotaService 创建配额服务实例。
|
||||
func NewSoraQuotaService(
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
settingService *SettingService,
|
||||
) *SoraQuotaService {
|
||||
return &SoraQuotaService{
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
settingService: settingService,
|
||||
}
|
||||
}
|
||||
|
||||
// QuotaInfo 返回给客户端的配额信息。
|
||||
type QuotaInfo struct {
|
||||
QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制)
|
||||
UsedBytes int64 `json:"used_bytes"` // 已使用
|
||||
AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0)
|
||||
QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited
|
||||
Source string `json:"source,omitempty"` // 兼容旧字段
|
||||
}
|
||||
|
||||
// ErrSoraStorageQuotaExceeded 表示配额不足。
|
||||
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
|
||||
|
||||
// QuotaExceededError 包含配额不足的上下文信息。
|
||||
type QuotaExceededError struct {
|
||||
QuotaBytes int64
|
||||
UsedBytes int64
|
||||
}
|
||||
|
||||
func (e *QuotaExceededError) Error() string {
|
||||
if e == nil {
|
||||
return "存储配额不足"
|
||||
}
|
||||
return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
|
||||
}
|
||||
|
||||
type soraQuotaAtomicUserRepository interface {
|
||||
AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
|
||||
ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
|
||||
}
|
||||
|
||||
// GetQuota 获取用户的存储配额信息。
|
||||
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
|
||||
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
info := &QuotaInfo{
|
||||
UsedBytes: user.SoraStorageUsedBytes,
|
||||
}
|
||||
|
||||
// 1. 用户级配额
|
||||
if user.SoraStorageQuotaBytes > 0 {
|
||||
info.QuotaBytes = user.SoraStorageQuotaBytes
|
||||
info.QuotaSource = "user"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// 2. 分组级配额(取用户可用分组中最大的配额)
|
||||
if len(user.AllowedGroups) > 0 {
|
||||
var maxGroupQuota int64
|
||||
for _, gid := range user.AllowedGroups {
|
||||
group, err := s.groupRepo.GetByID(ctx, gid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if group.SoraStorageQuotaBytes > maxGroupQuota {
|
||||
maxGroupQuota = group.SoraStorageQuotaBytes
|
||||
}
|
||||
}
|
||||
if maxGroupQuota > 0 {
|
||||
info.QuotaBytes = maxGroupQuota
|
||||
info.QuotaSource = "group"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 系统默认值
|
||||
defaultQuota := s.getSystemDefaultQuota(ctx)
|
||||
if defaultQuota > 0 {
|
||||
info.QuotaBytes = defaultQuota
|
||||
info.QuotaSource = "system"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// 无配额限制
|
||||
info.QuotaSource = "unlimited"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = 0
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// CheckQuota 检查用户是否有足够的存储配额。
|
||||
// 返回 nil 表示配额充足或无限制。
|
||||
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
|
||||
quota, err := s.GetQuota(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 0 表示无限制
|
||||
if quota.QuotaBytes == 0 {
|
||||
return nil
|
||||
}
|
||||
if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
|
||||
return &QuotaExceededError{
|
||||
QuotaBytes: quota.QuotaBytes,
|
||||
UsedBytes: quota.UsedBytes,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddUsage 原子累加用量(上传成功后调用)。
|
||||
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
|
||||
if bytes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
quota, err := s.GetQuota(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
|
||||
return &QuotaExceededError{
|
||||
QuotaBytes: quota.QuotaBytes,
|
||||
UsedBytes: quota.UsedBytes,
|
||||
}
|
||||
}
|
||||
|
||||
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
|
||||
newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSoraStorageQuotaExceeded) {
|
||||
return &QuotaExceededError{
|
||||
QuotaBytes: quota.QuotaBytes,
|
||||
UsedBytes: quota.UsedBytes,
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("update user quota usage (atomic): %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user for quota update: %w", err)
|
||||
}
|
||||
user.SoraStorageUsedBytes += bytes
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("update user quota usage: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReleaseUsage 释放用量(删除文件后调用)。
|
||||
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
|
||||
if bytes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
|
||||
newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update user quota release (atomic): %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user for quota release: %w", err)
|
||||
}
|
||||
user.SoraStorageUsedBytes -= bytes
|
||||
if user.SoraStorageUsedBytes < 0 {
|
||||
user.SoraStorageUsedBytes = 0
|
||||
}
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("update user quota release: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
|
||||
if quotaBytes <= 0 {
|
||||
return 0
|
||||
}
|
||||
if usedBytes >= quotaBytes {
|
||||
return 0
|
||||
}
|
||||
return quotaBytes - usedBytes
|
||||
}
|
||||
|
||||
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
|
||||
if s.settingService == nil {
|
||||
return 0
|
||||
}
|
||||
settings, err := s.settingService.GetSoraS3Settings(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return settings.DefaultStorageQuotaBytes
|
||||
}
|
||||
|
||||
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
|
||||
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
|
||||
return s.getSystemDefaultQuota(ctx)
|
||||
}
|
||||
|
||||
// SetUserQuota 设置用户级配额(管理员操作)。
|
||||
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
|
||||
user, err := userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.SoraStorageQuotaBytes = quotaBytes
|
||||
return userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
// ParseQuotaBytes 解析配额字符串为字节数。
|
||||
func ParseQuotaBytes(s string) int64 {
|
||||
v, _ := strconv.ParseInt(s, 10, 64)
|
||||
return v
|
||||
}
|
||||
@ -1,492 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ==================== Stub: GroupRepository (用于 SoraQuotaService) ====================
|
||||
|
||||
var _ GroupRepository = (*stubGroupRepoForQuota)(nil)
|
||||
|
||||
type stubGroupRepoForQuota struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func newStubGroupRepoForQuota() *stubGroupRepoForQuota {
|
||||
return &stubGroupRepoForQuota{groups: make(map[int64]*Group)}
|
||||
}
|
||||
|
||||
func (r *stubGroupRepoForQuota) GetByID(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := r.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, fmt.Errorf("group not found")
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) Create(context.Context, *Group) error { return nil }
|
||||
func (r *stubGroupRepoForQuota) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
return r.GetByID(context.Background(), id)
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) Update(context.Context, *Group) error { return nil }
|
||||
func (r *stubGroupRepoForQuota) Delete(context.Context, int64) error { return nil }
|
||||
func (r *stubGroupRepoForQuota) DeleteCascade(context.Context, int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) ListActive(context.Context) ([]Group, error) { return nil, nil }
|
||||
func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) BindAccountsToGroup(context.Context, int64, []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubGroupRepoForQuota) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Stub: SettingRepository (用于 SettingService) ====================
|
||||
|
||||
var _ SettingRepository = (*stubSettingRepoForQuota)(nil)
|
||||
|
||||
type stubSettingRepoForQuota struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func newStubSettingRepoForQuota(values map[string]string) *stubSettingRepoForQuota {
|
||||
if values == nil {
|
||||
values = make(map[string]string)
|
||||
}
|
||||
return &stubSettingRepoForQuota{values: values}
|
||||
}
|
||||
|
||||
func (r *stubSettingRepoForQuota) Get(_ context.Context, key string) (*Setting, error) {
|
||||
if v, ok := r.values[key]; ok {
|
||||
return &Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
func (r *stubSettingRepoForQuota) GetValue(_ context.Context, key string) (string, error) {
|
||||
if v, ok := r.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
func (r *stubSettingRepoForQuota) Set(_ context.Context, key, value string) error {
|
||||
r.values[key] = value
|
||||
return nil
|
||||
}
|
||||
func (r *stubSettingRepoForQuota) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if v, ok := r.values[k]; ok {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubSettingRepoForQuota) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
for k, v := range settings {
|
||||
r.values[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (r *stubSettingRepoForQuota) GetAll(_ context.Context) (map[string]string, error) {
|
||||
return r.values, nil
|
||||
}
|
||||
func (r *stubSettingRepoForQuota) Delete(_ context.Context, key string) error {
|
||||
delete(r.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== GetQuota ====================
|
||||
|
||||
func TestGetQuota_UserLevel(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 10 * 1024 * 1024, // 10MB
|
||||
SoraStorageUsedBytes: 3 * 1024 * 1024, // 3MB
|
||||
}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10*1024*1024), quota.QuotaBytes)
|
||||
require.Equal(t, int64(3*1024*1024), quota.UsedBytes)
|
||||
require.Equal(t, "user", quota.Source)
|
||||
}
|
||||
|
||||
func TestGetQuota_GroupLevel(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 0, // 用户级无配额
|
||||
SoraStorageUsedBytes: 1024,
|
||||
AllowedGroups: []int64{10, 20},
|
||||
}
|
||||
|
||||
groupRepo := newStubGroupRepoForQuota()
|
||||
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 5 * 1024 * 1024}
|
||||
groupRepo.groups[20] = &Group{ID: 20, SoraStorageQuotaBytes: 20 * 1024 * 1024}
|
||||
|
||||
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) // 取最大值
|
||||
require.Equal(t, "group", quota.Source)
|
||||
}
|
||||
|
||||
func TestGetQuota_SystemLevel(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 512}
|
||||
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
svc := NewSoraQuotaService(userRepo, nil, settingService)
|
||||
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(104857600), quota.QuotaBytes)
|
||||
require.Equal(t, "system", quota.Source)
|
||||
}
|
||||
|
||||
func TestGetQuota_NoLimit(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 0}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), quota.QuotaBytes)
|
||||
require.Equal(t, "unlimited", quota.Source)
|
||||
}
|
||||
|
||||
func TestGetQuota_UserNotFound(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
_, err := svc.GetQuota(context.Background(), 999)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "get user")
|
||||
}
|
||||
|
||||
func TestGetQuota_GroupRepoError(t *testing.T) {
|
||||
// 分组获取失败时跳过该分组(不影响整体)
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1, SoraStorageQuotaBytes: 0,
|
||||
AllowedGroups: []int64{999}, // 不存在的分组
|
||||
}
|
||||
|
||||
groupRepo := newStubGroupRepoForQuota()
|
||||
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
|
||||
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "unlimited", quota.Source) // 分组获取失败,回退到无限制
|
||||
}
|
||||
|
||||
// ==================== CheckQuota ====================
|
||||
|
||||
func TestCheckQuota_Sufficient(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 10 * 1024 * 1024,
|
||||
SoraStorageUsedBytes: 3 * 1024 * 1024,
|
||||
}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.CheckQuota(context.Background(), 1, 1024)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckQuota_Exceeded(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 10 * 1024 * 1024,
|
||||
SoraStorageUsedBytes: 10 * 1024 * 1024, // 已满
|
||||
}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.CheckQuota(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "配额不足")
|
||||
}
|
||||
|
||||
func TestCheckQuota_NoLimit(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 0, // 无限制
|
||||
SoraStorageUsedBytes: 1000000000,
|
||||
}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.CheckQuota(context.Background(), 1, 999999999)
|
||||
require.NoError(t, err) // 无限制时始终通过
|
||||
}
|
||||
|
||||
func TestCheckQuota_ExactBoundary(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 1024,
|
||||
SoraStorageUsedBytes: 1024, // 恰好满
|
||||
}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
// 额外 0 字节不超
|
||||
require.NoError(t, svc.CheckQuota(context.Background(), 1, 0))
|
||||
// 额外 1 字节超出
|
||||
require.Error(t, svc.CheckQuota(context.Background(), 1, 1))
|
||||
}
|
||||
|
||||
// ==================== AddUsage ====================
|
||||
|
||||
func TestAddUsage_Success(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.AddUsage(context.Background(), 1, 2048)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3072), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestAddUsage_ZeroBytes(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.AddUsage(context.Background(), 1, 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
|
||||
}
|
||||
|
||||
func TestAddUsage_NegativeBytes(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.AddUsage(context.Background(), 1, -100)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
|
||||
}
|
||||
|
||||
func TestAddUsage_UserNotFound(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.AddUsage(context.Background(), 999, 1024)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAddUsage_UpdateError(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 0}
|
||||
userRepo.updateErr = fmt.Errorf("db error")
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.AddUsage(context.Background(), 1, 1024)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update user quota usage")
|
||||
}
|
||||
|
||||
// ==================== ReleaseUsage ====================
|
||||
|
||||
func TestReleaseUsage_Success(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 3072}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.ReleaseUsage(context.Background(), 1, 1024)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2048), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestReleaseUsage_ClampToZero(t *testing.T) {
|
||||
// 释放量大于已用量时,应 clamp 到 0
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 500}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.ReleaseUsage(context.Background(), 1, 1000)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestReleaseUsage_ZeroBytes(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.ReleaseUsage(context.Background(), 1, 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
|
||||
}
|
||||
|
||||
func TestReleaseUsage_NegativeBytes(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.ReleaseUsage(context.Background(), 1, -50)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
|
||||
}
|
||||
|
||||
func TestReleaseUsage_UserNotFound(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.ReleaseUsage(context.Background(), 999, 1024)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReleaseUsage_UpdateError(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
userRepo.updateErr = fmt.Errorf("db error")
|
||||
svc := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
err := svc.ReleaseUsage(context.Background(), 1, 512)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update user quota release")
|
||||
}
|
||||
|
||||
// ==================== GetQuotaFromSettings ====================
|
||||
|
||||
func TestGetQuotaFromSettings_NilSettingService(t *testing.T) {
|
||||
svc := NewSoraQuotaService(nil, nil, nil)
|
||||
require.Equal(t, int64(0), svc.GetQuotaFromSettings(context.Background()))
|
||||
}
|
||||
|
||||
func TestGetQuotaFromSettings_WithSettings(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
svc := NewSoraQuotaService(nil, nil, settingService)
|
||||
|
||||
require.Equal(t, int64(52428800), svc.GetQuotaFromSettings(context.Background()))
|
||||
}
|
||||
|
||||
// ==================== SetUserSoraQuota ====================
|
||||
|
||||
func TestSetUserSoraQuota_Success(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0}
|
||||
|
||||
err := SetUserSoraQuota(context.Background(), userRepo, 1, 10*1024*1024)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10*1024*1024), userRepo.users[1].SoraStorageQuotaBytes)
|
||||
}
|
||||
|
||||
func TestSetUserSoraQuota_UserNotFound(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
err := SetUserSoraQuota(context.Background(), userRepo, 999, 1024)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== ParseQuotaBytes ====================
|
||||
|
||||
func TestParseQuotaBytes(t *testing.T) {
|
||||
require.Equal(t, int64(1048576), ParseQuotaBytes("1048576"))
|
||||
require.Equal(t, int64(0), ParseQuotaBytes(""))
|
||||
require.Equal(t, int64(0), ParseQuotaBytes("abc"))
|
||||
require.Equal(t, int64(-1), ParseQuotaBytes("-1"))
|
||||
}
|
||||
|
||||
// ==================== 优先级完整测试 ====================
|
||||
|
||||
func TestQuotaPriority_UserOverridesGroup(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 5 * 1024 * 1024,
|
||||
AllowedGroups: []int64{10},
|
||||
}
|
||||
|
||||
groupRepo := newStubGroupRepoForQuota()
|
||||
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
|
||||
|
||||
svc := NewSoraQuotaService(userRepo, groupRepo, nil)
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user", quota.Source) // 用户级优先
|
||||
require.Equal(t, int64(5*1024*1024), quota.QuotaBytes)
|
||||
}
|
||||
|
||||
func TestQuotaPriority_GroupOverridesSystem(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 0,
|
||||
AllowedGroups: []int64{10},
|
||||
}
|
||||
|
||||
groupRepo := newStubGroupRepoForQuota()
|
||||
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
|
||||
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
|
||||
svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "group", quota.Source) // 分组级优先于系统
|
||||
require.Equal(t, int64(20*1024*1024), quota.QuotaBytes)
|
||||
}
|
||||
|
||||
func TestQuotaPriority_FallbackToSystem(t *testing.T) {
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{
|
||||
ID: 1,
|
||||
SoraStorageQuotaBytes: 0,
|
||||
AllowedGroups: []int64{10},
|
||||
}
|
||||
|
||||
groupRepo := newStubGroupRepoForQuota()
|
||||
groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 0} // 分组无配额
|
||||
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
|
||||
svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
|
||||
quota, err := svc.GetQuota(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "system", quota.Source)
|
||||
require.Equal(t, int64(52428800), quota.QuotaBytes)
|
||||
}
|
||||
@ -1,392 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。
|
||||
// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。
|
||||
type SoraS3Storage struct {
|
||||
settingService *SettingService
|
||||
|
||||
mu sync.RWMutex
|
||||
client *s3.Client
|
||||
cfg *SoraS3Settings // 上次加载的配置快照
|
||||
|
||||
healthCheckedAt time.Time
|
||||
healthErr error
|
||||
healthTTL time.Duration
|
||||
}
|
||||
|
||||
const defaultSoraS3HealthTTL = 30 * time.Second
|
||||
|
||||
// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。
|
||||
type UpstreamDownloadError struct {
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (e *UpstreamDownloadError) Error() string {
|
||||
if e == nil {
|
||||
return "upstream download failed"
|
||||
}
|
||||
return fmt.Sprintf("upstream returned %d", e.StatusCode)
|
||||
}
|
||||
|
||||
// NewSoraS3Storage 创建 S3 存储服务实例。
|
||||
func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage {
|
||||
return &SoraS3Storage{
|
||||
settingService: settingService,
|
||||
healthTTL: defaultSoraS3HealthTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled 返回 S3 存储是否已启用且配置有效。
|
||||
func (s *SoraS3Storage) Enabled(ctx context.Context) bool {
|
||||
cfg, err := s.getConfig(ctx)
|
||||
if err != nil || cfg == nil {
|
||||
return false
|
||||
}
|
||||
return cfg.Enabled && cfg.Bucket != ""
|
||||
}
|
||||
|
||||
// getConfig 获取当前 S3 配置(从 settings 表读取)。
|
||||
func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) {
|
||||
if s.settingService == nil {
|
||||
return nil, fmt.Errorf("setting service not available")
|
||||
}
|
||||
return s.settingService.GetSoraS3Settings(ctx)
|
||||
}
|
||||
|
||||
// getClient 获取或初始化 S3 客户端(带缓存)。
|
||||
// 配置变更时调用 RefreshClient 清除缓存。
|
||||
func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
|
||||
s.mu.RLock()
|
||||
if s.client != nil && s.cfg != nil {
|
||||
client, cfg := s.client, s.cfg
|
||||
s.mu.RUnlock()
|
||||
return client, cfg, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
return s.initClient(ctx)
|
||||
}
|
||||
|
||||
func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if s.client != nil && s.cfg != nil {
|
||||
return s.client, s.cfg, nil
|
||||
}
|
||||
|
||||
cfg, err := s.getConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("load s3 config: %w", err)
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
return nil, nil, fmt.Errorf("sora s3 storage is disabled")
|
||||
}
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
client, region, err := buildSoraS3Client(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.client = client
|
||||
s.cfg = cfg
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region)
|
||||
return client, cfg, nil
|
||||
}
|
||||
|
||||
// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。
|
||||
// 应在系统设置中 S3 配置变更时调用。
|
||||
func (s *SoraS3Storage) RefreshClient() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.client = nil
|
||||
s.cfg = nil
|
||||
s.healthCheckedAt = time.Time{}
|
||||
s.healthErr = nil
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化")
|
||||
}
|
||||
|
||||
// TestConnection 测试 S3 连接(HeadBucket)。
|
||||
func (s *SoraS3Storage) TestConnection(ctx context.Context) error {
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。
|
||||
func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.RLock()
|
||||
lastCheck := s.healthCheckedAt
|
||||
lastErr := s.healthErr
|
||||
ttl := s.healthTTL
|
||||
s.mu.RUnlock()
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = defaultSoraS3HealthTTL
|
||||
}
|
||||
if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl {
|
||||
return lastErr == nil
|
||||
}
|
||||
|
||||
err := s.TestConnection(ctx)
|
||||
s.mu.Lock()
|
||||
s.healthCheckedAt = time.Now()
|
||||
s.healthErr = err
|
||||
s.mu.Unlock()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。
|
||||
func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("s3 config is required")
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
return fmt.Errorf("sora s3 storage is disabled")
|
||||
}
|
||||
if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
client, _, err := buildSoraS3Client(ctx, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateObjectKey 生成 S3 object key。
|
||||
// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext}
|
||||
func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
ext = "." + ext
|
||||
}
|
||||
datePath := time.Now().Format("2006/01/02")
|
||||
key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext)
|
||||
if prefix != "" {
|
||||
prefix = strings.TrimRight(prefix, "/") + "/"
|
||||
key = prefix + key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// UploadFromURL 从上游 URL 下载并流式上传到 S3。
|
||||
// 返回 S3 object key。
|
||||
func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) {
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
// 下载源文件
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
httpClient := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("download from upstream: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 推断文件扩展名
|
||||
ext := fileExtFromURL(sourceURL)
|
||||
if ext == "" {
|
||||
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
|
||||
objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext)
|
||||
|
||||
// 检测 Content-Type
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
reader, writer := io.Pipe()
|
||||
uploadErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(uploadErrCh)
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
Key: &objectKey,
|
||||
Body: reader,
|
||||
ContentType: &contentType,
|
||||
}
|
||||
if resp.ContentLength >= 0 {
|
||||
input.ContentLength = &resp.ContentLength
|
||||
}
|
||||
_, uploadErr := client.PutObject(ctx, input)
|
||||
uploadErrCh <- uploadErr
|
||||
}()
|
||||
|
||||
written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024))
|
||||
_ = writer.CloseWithError(copyErr)
|
||||
uploadErr := <-uploadErrCh
|
||||
if copyErr != nil {
|
||||
return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr)
|
||||
}
|
||||
if uploadErr != nil {
|
||||
return "", 0, fmt.Errorf("s3 upload: %w", uploadErr)
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written)
|
||||
return objectKey, written, nil
|
||||
}
|
||||
|
||||
func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) {
|
||||
if cfg == nil {
|
||||
return nil, "", fmt.Errorf("s3 config is required")
|
||||
}
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
// 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
return client, region, nil
|
||||
}
|
||||
|
||||
// DeleteObjects 删除一组 S3 object(遍历逐一删除)。
|
||||
func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error {
|
||||
if len(objectKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, key := range objectKeys {
|
||||
k := key
|
||||
_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
Key: &k,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetAccessURL 获取 S3 文件的访问 URL。
|
||||
// CDN URL 优先,否则生成 24h 预签名 URL。
|
||||
func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) {
|
||||
_, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// CDN URL 优先
|
||||
if cfg.CDNURL != "" {
|
||||
cdnBase := strings.TrimRight(cfg.CDNURL, "/")
|
||||
return cdnBase + "/" + objectKey, nil
|
||||
}
|
||||
|
||||
// 生成 24h 预签名 URL
|
||||
return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour)
|
||||
}
|
||||
|
||||
// GeneratePresignedURL 生成预签名 URL。
|
||||
func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) {
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
presignClient := s3.NewPresignClient(client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
Key: &objectKey,
|
||||
}, s3.WithPresignExpires(ttl))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
// GetMediaType 从 object key 推断媒体类型(image/video)。
|
||||
func GetMediaTypeFromKey(objectKey string) string {
|
||||
ext := strings.ToLower(path.Ext(objectKey))
|
||||
switch ext {
|
||||
case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
|
||||
return "video"
|
||||
default:
|
||||
return "image"
|
||||
}
|
||||
}
|
||||
@ -1,263 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ==================== RefreshClient ====================
|
||||
|
||||
func TestRefreshClient(t *testing.T) {
|
||||
s := newS3StorageWithCDN("https://cdn.example.com")
|
||||
require.NotNil(t, s.client)
|
||||
require.NotNil(t, s.cfg)
|
||||
|
||||
s.RefreshClient()
|
||||
require.Nil(t, s.client)
|
||||
require.Nil(t, s.cfg)
|
||||
}
|
||||
|
||||
func TestRefreshClient_AlreadyNil(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
s.RefreshClient() // 不应 panic
|
||||
require.Nil(t, s.client)
|
||||
require.Nil(t, s.cfg)
|
||||
}
|
||||
|
||||
// ==================== GetMediaTypeFromKey ====================
|
||||
|
||||
func TestGetMediaTypeFromKey_VideoExtensions(t *testing.T) {
|
||||
for _, ext := range []string{".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv"} {
|
||||
require.Equal(t, "video", GetMediaTypeFromKey("path/to/file"+ext), "ext=%s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMediaTypeFromKey_VideoUpperCase(t *testing.T) {
|
||||
require.Equal(t, "video", GetMediaTypeFromKey("file.MP4"))
|
||||
require.Equal(t, "video", GetMediaTypeFromKey("file.MOV"))
|
||||
}
|
||||
|
||||
func TestGetMediaTypeFromKey_ImageExtensions(t *testing.T) {
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.png"))
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.jpg"))
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.jpeg"))
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.gif"))
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.webp"))
|
||||
}
|
||||
|
||||
func TestGetMediaTypeFromKey_NoExtension(t *testing.T) {
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file"))
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("path/to/file"))
|
||||
}
|
||||
|
||||
func TestGetMediaTypeFromKey_UnknownExtension(t *testing.T) {
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.bin"))
|
||||
require.Equal(t, "image", GetMediaTypeFromKey("file.xyz"))
|
||||
}
|
||||
|
||||
// ==================== Enabled ====================
|
||||
|
||||
func TestEnabled_NilSettingService(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
require.False(t, s.Enabled(context.Background()))
|
||||
}
|
||||
|
||||
func TestEnabled_ConfigDisabled(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "false",
|
||||
SettingKeySoraS3Bucket: "test-bucket",
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
require.False(t, s.Enabled(context.Background()))
|
||||
}
|
||||
|
||||
func TestEnabled_ConfigEnabledWithBucket(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "true",
|
||||
SettingKeySoraS3Bucket: "my-bucket",
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
require.True(t, s.Enabled(context.Background()))
|
||||
}
|
||||
|
||||
func TestEnabled_ConfigEnabledEmptyBucket(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "true",
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
require.False(t, s.Enabled(context.Background()))
|
||||
}
|
||||
|
||||
// ==================== initClient ====================
|
||||
|
||||
func TestInitClient_Disabled(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "false",
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
|
||||
_, _, err := s.getClient(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "disabled")
|
||||
}
|
||||
|
||||
func TestInitClient_IncompleteConfig(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "true",
|
||||
SettingKeySoraS3Bucket: "test-bucket",
|
||||
// 缺少 access_key_id 和 secret_access_key
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
|
||||
_, _, err := s.getClient(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "incomplete")
|
||||
}
|
||||
|
||||
func TestInitClient_DefaultRegion(t *testing.T) {
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "true",
|
||||
SettingKeySoraS3Bucket: "test-bucket",
|
||||
SettingKeySoraS3AccessKeyID: "AKID",
|
||||
SettingKeySoraS3SecretAccessKey: "SECRET",
|
||||
// Region 为空 → 默认 us-east-1
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
|
||||
client, cfg, err := s.getClient(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
require.Equal(t, "test-bucket", cfg.Bucket)
|
||||
}
|
||||
|
||||
func TestInitClient_DoubleCheck(t *testing.T) {
|
||||
// 验证双重检查锁定:第二次 getClient 命中缓存
|
||||
settingRepo := newStubSettingRepoForQuota(map[string]string{
|
||||
SettingKeySoraS3Enabled: "true",
|
||||
SettingKeySoraS3Bucket: "test-bucket",
|
||||
SettingKeySoraS3AccessKeyID: "AKID",
|
||||
SettingKeySoraS3SecretAccessKey: "SECRET",
|
||||
})
|
||||
settingService := NewSettingService(settingRepo, &config.Config{})
|
||||
s := NewSoraS3Storage(settingService)
|
||||
|
||||
client1, _, err1 := s.getClient(context.Background())
|
||||
require.NoError(t, err1)
|
||||
client2, _, err2 := s.getClient(context.Background())
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, client1, client2) // 同一客户端实例
|
||||
}
|
||||
|
||||
func TestInitClient_NilSettingService(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
_, _, err := s.getClient(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "setting service not available")
|
||||
}
|
||||
|
||||
// ==================== GenerateObjectKey ====================
|
||||
|
||||
func TestGenerateObjectKey_ExtWithoutDot(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
key := s.GenerateObjectKey("", 1, "mp4")
|
||||
require.Contains(t, key, ".mp4")
|
||||
require.True(t, len(key) > 0)
|
||||
}
|
||||
|
||||
func TestGenerateObjectKey_ExtWithDot(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
key := s.GenerateObjectKey("", 1, ".mp4")
|
||||
require.Contains(t, key, ".mp4")
|
||||
// 不应出现 ..mp4
|
||||
require.NotContains(t, key, "..mp4")
|
||||
}
|
||||
|
||||
func TestGenerateObjectKey_WithPrefix(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
key := s.GenerateObjectKey("uploads/", 42, ".png")
|
||||
require.True(t, len(key) > 0)
|
||||
require.Contains(t, key, "uploads/sora/42/")
|
||||
}
|
||||
|
||||
func TestGenerateObjectKey_PrefixWithoutTrailingSlash(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
key := s.GenerateObjectKey("uploads", 42, ".png")
|
||||
require.Contains(t, key, "uploads/sora/42/")
|
||||
}
|
||||
|
||||
// ==================== GeneratePresignedURL ====================
|
||||
|
||||
func TestGeneratePresignedURL_GetClientError(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil) // settingService=nil → getClient 失败
|
||||
_, err := s.GeneratePresignedURL(context.Background(), "key", 3600)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== GetAccessURL ====================
|
||||
|
||||
func TestGetAccessURL_CDN(t *testing.T) {
|
||||
s := newS3StorageWithCDN("https://cdn.example.com")
|
||||
url, err := s.GetAccessURL(context.Background(), "sora/1/2024/01/01/video.mp4")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", url)
|
||||
}
|
||||
|
||||
func TestGetAccessURL_CDNTrailingSlash(t *testing.T) {
|
||||
s := newS3StorageWithCDN("https://cdn.example.com/")
|
||||
url, err := s.GetAccessURL(context.Background(), "key.mp4")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://cdn.example.com/key.mp4", url)
|
||||
}
|
||||
|
||||
func TestGetAccessURL_GetClientError(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
_, err := s.GetAccessURL(context.Background(), "key")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== TestConnection ====================
|
||||
|
||||
func TestTestConnection_GetClientError(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
err := s.TestConnection(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== UploadFromURL ====================
|
||||
|
||||
func TestUploadFromURL_GetClientError(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
_, _, err := s.UploadFromURL(context.Background(), 1, "https://example.com/file.mp4")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== DeleteObjects ====================
|
||||
|
||||
func TestDeleteObjects_EmptyKeys(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
err := s.DeleteObjects(context.Background(), []string{})
|
||||
require.NoError(t, err) // 空列表直接返回
|
||||
}
|
||||
|
||||
func TestDeleteObjects_NilKeys(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
err := s.DeleteObjects(context.Background(), nil)
|
||||
require.NoError(t, err) // nil 列表直接返回
|
||||
}
|
||||
|
||||
func TestDeleteObjects_GetClientError(t *testing.T) {
|
||||
s := NewSoraS3Storage(nil)
|
||||
err := s.DeleteObjects(context.Background(), []string{"key1", "key2"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,149 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
|
||||
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
|
||||
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
|
||||
// 支持流式和非流式响应的直接透传。
|
||||
func (s *SoraGatewayService) forwardToUpstream(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
clientStream bool,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
|
||||
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
|
||||
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
|
||||
}
|
||||
// 校验 scheme 合法性(仅允许 http/https)
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
|
||||
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
|
||||
}
|
||||
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
|
||||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||
}
|
||||
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// 透传客户端的部分请求头
|
||||
for _, header := range []string{"Accept", "Accept-Encoding"} {
|
||||
if v := c.GetHeader(header); v != "" {
|
||||
upstreamReq.Header.Set(header, v)
|
||||
}
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
|
||||
|
||||
// 获取代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 错误响应处理
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// 非转移错误,直接透传给客户端
|
||||
c.Status(resp.StatusCode)
|
||||
for key, values := range resp.Header {
|
||||
for _, v := range values {
|
||||
c.Writer.Header().Add(key, v)
|
||||
}
|
||||
}
|
||||
if _, err := c.Writer.Write(respBody); err != nil {
|
||||
return nil, fmt.Errorf("write upstream error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 成功响应 — 直接透传
|
||||
c.Status(resp.StatusCode)
|
||||
for key, values := range resp.Header {
|
||||
lower := strings.ToLower(key)
|
||||
// 透传内容相关头部
|
||||
if lower == "content-type" || lower == "transfer-encoding" ||
|
||||
lower == "cache-control" || lower == "x-request-id" {
|
||||
for _, v := range values {
|
||||
c.Writer.Header().Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 流式复制响应体
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, err := c.Writer.Write(buf[:n]); err != nil {
|
||||
return nil, fmt.Errorf("stream upstream response write: %w", err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
return nil, fmt.Errorf("copy upstream response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Model: "", // 由调用方填充
|
||||
Stream: clientStream,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
|
||||
// Antigravity 同样可能有两种缓存键
|
||||
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
||||
case PlatformOpenAI, PlatformSora:
|
||||
case PlatformOpenAI:
|
||||
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
||||
case PlatformAnthropic:
|
||||
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
|
||||
|
||||
@ -60,7 +60,6 @@ func NewTokenRefreshService(
|
||||
}
|
||||
|
||||
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||
|
||||
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||
@ -85,18 +84,6 @@ func NewTokenRefreshService(
|
||||
return s
|
||||
}
|
||||
|
||||
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
|
||||
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
|
||||
// 需要在 Start() 之前调用
|
||||
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
|
||||
for _, refresher := range s.refreshers {
|
||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
||||
openaiRefresher.SetSoraAccountRepo(repo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖
|
||||
func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) {
|
||||
s.privacyClientFactory = factory
|
||||
|
||||
@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -73,8 +72,6 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
type OpenAITokenRefresher struct {
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||
syncLinkedSora bool
|
||||
}
|
||||
|
||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||
@ -90,20 +87,7 @@ func (r *OpenAITokenRefresher) CacheKey(account *Account) string {
|
||||
return OpenAITokenCacheKey(account)
|
||||
}
|
||||
|
||||
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
|
||||
// 用于在 Token 刷新时同步更新 sora_accounts 表
|
||||
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
|
||||
func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
r.soraAccountRepo = repo
|
||||
}
|
||||
|
||||
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
|
||||
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
|
||||
r.syncLinkedSora = enabled
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||
}
|
||||
@ -121,7 +105,6 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
// 刷新成功后,异步同步关联的 Sora 账号
|
||||
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
@ -132,68 +115,5 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||
|
||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||
if r.accountRepo != nil && r.syncLinkedSora {
|
||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步)
|
||||
// 该方法异步执行,失败只记录日志,不影响主流程
|
||||
//
|
||||
// 同步策略:
|
||||
// 1. 更新 accounts.credentials(主表)
|
||||
// 2. 更新 sora_accounts 扩展表(如果 soraAccountRepo 已设置)
|
||||
//
|
||||
// 超时控制:30 秒,防止数据库阻塞导致 goroutine 泄漏
|
||||
func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
|
||||
// 添加超时控制,防止 goroutine 泄漏
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 1. 查找所有关联的 Sora 账号(限定 platform='sora')
|
||||
soraAccounts, err := r.accountRepo.FindByExtraField(ctx, "linked_openai_account_id", openaiAccountID)
|
||||
if err != nil {
|
||||
log.Printf("[TokenSync] 查找关联 Sora 账号失败: openai_account_id=%d err=%v", openaiAccountID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(soraAccounts) == 0 {
|
||||
// 没有关联的 Sora 账号,直接返回
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 同步更新每个 Sora 账号的双表数据
|
||||
for _, soraAccount := range soraAccounts {
|
||||
// 2.1 更新 accounts.credentials(主表)
|
||||
soraAccount.Credentials["access_token"] = newCredentials["access_token"]
|
||||
soraAccount.Credentials["refresh_token"] = newCredentials["refresh_token"]
|
||||
if expiresAt, ok := newCredentials["expires_at"]; ok {
|
||||
soraAccount.Credentials["expires_at"] = expiresAt
|
||||
}
|
||||
|
||||
if err := r.accountRepo.Update(ctx, &soraAccount); err != nil {
|
||||
log.Printf("[TokenSync] 更新 Sora accounts 表失败: sora_account_id=%d openai_account_id=%d err=%v",
|
||||
soraAccount.ID, openaiAccountID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 2.2 更新 sora_accounts 扩展表(如果仓储已设置)
|
||||
if r.soraAccountRepo != nil {
|
||||
soraUpdates := map[string]any{
|
||||
"access_token": newCredentials["access_token"],
|
||||
"refresh_token": newCredentials["refresh_token"],
|
||||
}
|
||||
if err := r.soraAccountRepo.Upsert(ctx, soraAccount.ID, soraUpdates); err != nil {
|
||||
log.Printf("[TokenSync] 更新 sora_accounts 表失败: account_id=%d openai_account_id=%d err=%v",
|
||||
soraAccount.ID, openaiAccountID, err)
|
||||
// 继续处理其他账号,不中断
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
|
||||
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
|
||||
}
|
||||
}
|
||||
|
||||
@ -242,12 +242,6 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
|
||||
accType: AccountTypeOAuth,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "sora oauth - cannot refresh directly",
|
||||
platform: PlatformSora,
|
||||
accType: AccountTypeOAuth,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "openai apikey - cannot refresh",
|
||||
platform: PlatformOpenAI,
|
||||
|
||||
@ -110,7 +110,7 @@ type UsageLog struct {
|
||||
ModelMappingChain *string
|
||||
// BillingTier 计费层级标签(per_request/image 模式)
|
||||
BillingTier *string
|
||||
// BillingMode 计费模式:token/image(sora 路径为 nil)
|
||||
// BillingMode 计费模式:token/image
|
||||
BillingMode *string
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string
|
||||
|
||||
@ -25,10 +25,6 @@ type User struct {
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 // 用户级 Sora 存储配额(0 表示使用分组或系统默认值)
|
||||
SoraStorageUsedBytes int64 // Sora 存储已用量
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
|
||||
TotpEnabled bool // 是否启用 TOTP
|
||||
|
||||
@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
@ -54,8 +53,6 @@ func ProvideTokenRefreshService(
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
||||
@ -281,30 +278,6 @@ func ProvideOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink {
|
||||
return sink
|
||||
}
|
||||
|
||||
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
|
||||
func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
return NewSoraMediaStorage(cfg)
|
||||
}
|
||||
|
||||
func ProvideSoraSDKClient(
|
||||
cfg *config.Config,
|
||||
httpUpstream HTTPUpstream,
|
||||
tokenProvider *OpenAITokenProvider,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
) *SoraSDKClient {
|
||||
client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider)
|
||||
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||
return client
|
||||
}
|
||||
|
||||
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
||||
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
func buildIdempotencyConfig(cfg *config.Config) IdempotencyConfig {
|
||||
idempotencyCfg := DefaultIdempotencyConfig()
|
||||
if cfg != nil {
|
||||
@ -425,11 +398,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
ProvideSoraMediaStorage,
|
||||
ProvideSoraMediaCleanupService,
|
||||
ProvideSoraSDKClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraSDKClient)),
|
||||
NewSoraGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package soraerror
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -1,47 +0,0 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCloudflareChallengeResponse(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
|
||||
|
||||
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
|
||||
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
|
||||
}
|
||||
|
||||
func TestExtractCloudflareRayID(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
|
||||
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
|
||||
}
|
||||
|
||||
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
|
||||
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
|
||||
require.Equal(t, "cf_shield_429", code)
|
||||
require.Equal(t, "rate limited", msg)
|
||||
|
||||
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
|
||||
require.Equal(t, "unsupported_country_code", code)
|
||||
require.Equal(t, "not available", msg)
|
||||
|
||||
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
|
||||
require.Equal(t, "", code)
|
||||
require.Equal(t, "plain text", msg)
|
||||
}
|
||||
|
||||
func TestFormatCloudflareChallengeMessage(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
|
||||
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
|
||||
}
|
||||
@ -256,7 +256,6 @@ func shouldBypassEmbeddedFrontend(path string) bool {
|
||||
return strings.HasPrefix(trimmed, "/api/") ||
|
||||
strings.HasPrefix(trimmed, "/v1/") ||
|
||||
strings.HasPrefix(trimmed, "/v1beta/") ||
|
||||
strings.HasPrefix(trimmed, "/sora/") ||
|
||||
strings.HasPrefix(trimmed, "/antigravity/") ||
|
||||
strings.HasPrefix(trimmed, "/setup/") ||
|
||||
trimmed == "/health" ||
|
||||
|
||||
@ -434,7 +434,6 @@ func TestFrontendServer_Middleware(t *testing.T) {
|
||||
"/api/v1/users",
|
||||
"/v1/models",
|
||||
"/v1beta/chat",
|
||||
"/sora/v1/models",
|
||||
"/antigravity/test",
|
||||
"/setup/init",
|
||||
"/health",
|
||||
@ -637,7 +636,6 @@ func TestServeEmbeddedFrontend(t *testing.T) {
|
||||
"/api/users",
|
||||
"/v1/models",
|
||||
"/v1beta/chat",
|
||||
"/sora/v1/models",
|
||||
"/antigravity/test",
|
||||
"/setup/init",
|
||||
"/health",
|
||||
|
||||
@ -1,80 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import {
|
||||
normalizeGenerationListResponse,
|
||||
normalizeModelFamiliesResponse
|
||||
} from '../sora'
|
||||
|
||||
describe('sora api normalizers', () => {
|
||||
it('normalizes generation list from data shape', () => {
|
||||
const result = normalizeGenerationListResponse({
|
||||
data: [{ id: 1, status: 'pending' }],
|
||||
total: 9,
|
||||
page: 2
|
||||
})
|
||||
|
||||
expect(result.data).toHaveLength(1)
|
||||
expect(result.total).toBe(9)
|
||||
expect(result.page).toBe(2)
|
||||
})
|
||||
|
||||
it('normalizes generation list from items shape', () => {
|
||||
const result = normalizeGenerationListResponse({
|
||||
items: [{ id: 1, status: 'completed' }],
|
||||
total: 1
|
||||
})
|
||||
|
||||
expect(result.data).toHaveLength(1)
|
||||
expect(result.total).toBe(1)
|
||||
expect(result.page).toBe(1)
|
||||
})
|
||||
|
||||
it('falls back to empty generation list on invalid payload', () => {
|
||||
const result = normalizeGenerationListResponse(null)
|
||||
expect(result).toEqual({ data: [], total: 0, page: 1 })
|
||||
})
|
||||
|
||||
it('normalizes family model payload', () => {
|
||||
const result = normalizeModelFamiliesResponse({
|
||||
data: [
|
||||
{
|
||||
id: 'sora2',
|
||||
name: 'Sora 2',
|
||||
type: 'video',
|
||||
orientations: ['landscape', 'portrait'],
|
||||
durations: [10, 15]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
expect(result).toHaveLength(1)
|
||||
expect(result[0].id).toBe('sora2')
|
||||
expect(result[0].orientations).toEqual(['landscape', 'portrait'])
|
||||
expect(result[0].durations).toEqual([10, 15])
|
||||
})
|
||||
|
||||
it('normalizes legacy flat model list into families', () => {
|
||||
const result = normalizeModelFamiliesResponse({
|
||||
items: [
|
||||
{ id: 'sora2-landscape-10s', type: 'video' },
|
||||
{ id: 'sora2-portrait-15s', type: 'video' },
|
||||
{ id: 'gpt-image-square', type: 'image' }
|
||||
]
|
||||
})
|
||||
|
||||
const sora2 = result.find((m) => m.id === 'sora2')
|
||||
expect(sora2).toBeTruthy()
|
||||
expect(sora2?.orientations).toEqual(['landscape', 'portrait'])
|
||||
expect(sora2?.durations).toEqual([10, 15])
|
||||
|
||||
const image = result.find((m) => m.id === 'gpt-image')
|
||||
expect(image).toBeTruthy()
|
||||
expect(image?.type).toBe('image')
|
||||
expect(image?.orientations).toEqual(['square'])
|
||||
})
|
||||
|
||||
it('falls back to empty families on invalid payload', () => {
|
||||
expect(normalizeModelFamiliesResponse(undefined)).toEqual([])
|
||||
expect(normalizeModelFamiliesResponse({})).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
@ -40,7 +40,6 @@ export interface SystemSettings {
|
||||
hide_ccs_import_button: boolean
|
||||
purchase_subscription_enabled: boolean
|
||||
purchase_subscription_url: string
|
||||
sora_client_enabled: boolean
|
||||
backend_mode_enabled: boolean
|
||||
custom_menu_items: CustomMenuItem[]
|
||||
custom_endpoints: CustomEndpoint[]
|
||||
@ -114,7 +113,6 @@ export interface UpdateSettingsRequest {
|
||||
hide_ccs_import_button?: boolean
|
||||
purchase_subscription_enabled?: boolean
|
||||
purchase_subscription_url?: string
|
||||
sora_client_enabled?: boolean
|
||||
backend_mode_enabled?: boolean
|
||||
custom_menu_items?: CustomMenuItem[]
|
||||
custom_endpoints?: CustomEndpoint[]
|
||||
@ -394,142 +392,6 @@ export async function updateBetaPolicySettings(
|
||||
return data
|
||||
}
|
||||
|
||||
// ==================== Sora S3 Settings ====================
|
||||
|
||||
export interface SoraS3Settings {
|
||||
enabled: boolean
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key_configured: boolean
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
cdn_url: string
|
||||
default_storage_quota_bytes: number
|
||||
}
|
||||
|
||||
export interface SoraS3Profile {
|
||||
profile_id: string
|
||||
name: string
|
||||
is_active: boolean
|
||||
enabled: boolean
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key_configured: boolean
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
cdn_url: string
|
||||
default_storage_quota_bytes: number
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ListSoraS3ProfilesResponse {
|
||||
active_profile_id: string
|
||||
items: SoraS3Profile[]
|
||||
}
|
||||
|
||||
export interface UpdateSoraS3SettingsRequest {
|
||||
profile_id?: string
|
||||
enabled: boolean
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key?: string
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
cdn_url: string
|
||||
default_storage_quota_bytes: number
|
||||
}
|
||||
|
||||
export interface CreateSoraS3ProfileRequest {
|
||||
profile_id: string
|
||||
name: string
|
||||
set_active?: boolean
|
||||
enabled: boolean
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key?: string
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
cdn_url: string
|
||||
default_storage_quota_bytes: number
|
||||
}
|
||||
|
||||
export interface UpdateSoraS3ProfileRequest {
|
||||
name: string
|
||||
enabled: boolean
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key?: string
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
cdn_url: string
|
||||
default_storage_quota_bytes: number
|
||||
}
|
||||
|
||||
export interface TestSoraS3ConnectionRequest {
|
||||
profile_id?: string
|
||||
enabled: boolean
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key?: string
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
cdn_url: string
|
||||
default_storage_quota_bytes?: number
|
||||
}
|
||||
|
||||
export async function getSoraS3Settings(): Promise<SoraS3Settings> {
|
||||
const { data } = await apiClient.get<SoraS3Settings>('/admin/settings/sora-s3')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateSoraS3Settings(settings: UpdateSoraS3SettingsRequest): Promise<SoraS3Settings> {
|
||||
const { data } = await apiClient.put<SoraS3Settings>('/admin/settings/sora-s3', settings)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function testSoraS3Connection(
|
||||
settings: TestSoraS3ConnectionRequest
|
||||
): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.post<{ message: string }>('/admin/settings/sora-s3/test', settings)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function listSoraS3Profiles(): Promise<ListSoraS3ProfilesResponse> {
|
||||
const { data } = await apiClient.get<ListSoraS3ProfilesResponse>('/admin/settings/sora-s3/profiles')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function createSoraS3Profile(request: CreateSoraS3ProfileRequest): Promise<SoraS3Profile> {
|
||||
const { data } = await apiClient.post<SoraS3Profile>('/admin/settings/sora-s3/profiles', request)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateSoraS3Profile(profileID: string, request: UpdateSoraS3ProfileRequest): Promise<SoraS3Profile> {
|
||||
const { data } = await apiClient.put<SoraS3Profile>(`/admin/settings/sora-s3/profiles/${profileID}`, request)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function deleteSoraS3Profile(profileID: string): Promise<void> {
|
||||
await apiClient.delete(`/admin/settings/sora-s3/profiles/${profileID}`)
|
||||
}
|
||||
|
||||
export async function setActiveSoraS3Profile(profileID: string): Promise<SoraS3Profile> {
|
||||
const { data } = await apiClient.post<SoraS3Profile>(`/admin/settings/sora-s3/profiles/${profileID}/activate`)
|
||||
return data
|
||||
}
|
||||
|
||||
export const settingsAPI = {
|
||||
getSettings,
|
||||
updateSettings,
|
||||
@ -545,15 +407,7 @@ export const settingsAPI = {
|
||||
getRectifierSettings,
|
||||
updateRectifierSettings,
|
||||
getBetaPolicySettings,
|
||||
updateBetaPolicySettings,
|
||||
getSoraS3Settings,
|
||||
updateSoraS3Settings,
|
||||
testSoraS3Connection,
|
||||
listSoraS3Profiles,
|
||||
createSoraS3Profile,
|
||||
updateSoraS3Profile,
|
||||
deleteSoraS3Profile,
|
||||
setActiveSoraS3Profile
|
||||
updateBetaPolicySettings
|
||||
}
|
||||
|
||||
export default settingsAPI
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user