Merge pull request #1463 from touwaeriol/feat/remove-sora
revert: completely remove Sora platform
This commit is contained in:
commit
d757df8a4b
@ -102,12 +102,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||||
soraAccountRepository := repository.NewSoraAccountRepository(db)
|
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
privacyClientFactory := providePrivacyClientFactory()
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
@ -184,12 +183,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
soraS3Storage := service.NewSoraS3Storage(settingService)
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||||
settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
|
|
||||||
soraGenerationRepository := repository.NewSoraGenerationRepository(db)
|
|
||||||
soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
|
|
||||||
soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
|
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
|
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||||
@ -223,16 +217,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||||
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
|
||||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
|
||||||
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
|
|
||||||
soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
|
|
||||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
|
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
totpHandler := handler.NewTotpHandler(totpService)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
||||||
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||||
@ -243,12 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@ -283,7 +271,6 @@ func provideCleanup(
|
|||||||
opsCleanup *service.OpsCleanupService,
|
opsCleanup *service.OpsCleanupService,
|
||||||
opsScheduledReport *service.OpsScheduledReportService,
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
opsSystemLogSink *service.OpsSystemLogSink,
|
opsSystemLogSink *service.OpsSystemLogSink,
|
||||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
|
||||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
@ -331,12 +318,6 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
{"SoraMediaCleanupService", func() error {
|
|
||||||
if soraMediaCleanup != nil {
|
|
||||||
soraMediaCleanup.Stop()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}},
|
|
||||||
{"OpsAlertEvaluatorService", func() error {
|
{"OpsAlertEvaluatorService", func() error {
|
||||||
if opsAlertEvaluator != nil {
|
if opsAlertEvaluator != nil {
|
||||||
opsAlertEvaluator.Stop()
|
opsAlertEvaluator.Stop()
|
||||||
|
|||||||
@ -57,7 +57,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
&service.OpsCleanupService{},
|
&service.OpsCleanupService{},
|
||||||
&service.OpsScheduledReportService{},
|
&service.OpsScheduledReportService{},
|
||||||
opsSystemLogSinkSvc,
|
opsSystemLogSinkSvc,
|
||||||
&service.SoraMediaCleanupService{},
|
|
||||||
schedulerSnapshotSvc,
|
schedulerSnapshotSvc,
|
||||||
tokenRefreshSvc,
|
tokenRefreshSvc,
|
||||||
accountExpirySvc,
|
accountExpirySvc,
|
||||||
|
|||||||
@ -52,16 +52,6 @@ type Group struct {
|
|||||||
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
||||||
// ImagePrice4k holds the value of the "image_price_4k" field.
|
// ImagePrice4k holds the value of the "image_price_4k" field.
|
||||||
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
||||||
// SoraImagePrice360 holds the value of the "sora_image_price_360" field.
|
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
|
||||||
// SoraImagePrice540 holds the value of the "sora_image_price_540" field.
|
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
|
||||||
// SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
|
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
|
||||||
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
|
||||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
|
||||||
// 是否仅允许 Claude Code 客户端
|
// 是否仅允许 Claude Code 客户端
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// 非 Claude Code 请求降级使用的分组 ID
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
@ -196,9 +186,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@ -335,40 +325,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
_m.ImagePrice4k = new(float64)
|
_m.ImagePrice4k = new(float64)
|
||||||
*_m.ImagePrice4k = value.Float64
|
*_m.ImagePrice4k = value.Float64
|
||||||
}
|
}
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraImagePrice360 = new(float64)
|
|
||||||
*_m.SoraImagePrice360 = value.Float64
|
|
||||||
}
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraImagePrice540 = new(float64)
|
|
||||||
*_m.SoraImagePrice540 = value.Float64
|
|
||||||
}
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraVideoPricePerRequest = new(float64)
|
|
||||||
*_m.SoraVideoPricePerRequest = value.Float64
|
|
||||||
}
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraVideoPricePerRequestHd = new(float64)
|
|
||||||
*_m.SoraVideoPricePerRequestHd = value.Float64
|
|
||||||
}
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraStorageQuotaBytes = value.Int64
|
|
||||||
}
|
|
||||||
case group.FieldClaudeCodeOnly:
|
case group.FieldClaudeCodeOnly:
|
||||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
||||||
@ -590,29 +546,6 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
if v := _m.SoraImagePrice360; v != nil {
|
|
||||||
builder.WriteString("sora_image_price_360=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
|
||||||
}
|
|
||||||
builder.WriteString(", ")
|
|
||||||
if v := _m.SoraImagePrice540; v != nil {
|
|
||||||
builder.WriteString("sora_image_price_540=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
|
||||||
}
|
|
||||||
builder.WriteString(", ")
|
|
||||||
if v := _m.SoraVideoPricePerRequest; v != nil {
|
|
||||||
builder.WriteString("sora_video_price_per_request=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
|
||||||
}
|
|
||||||
builder.WriteString(", ")
|
|
||||||
if v := _m.SoraVideoPricePerRequestHd; v != nil {
|
|
||||||
builder.WriteString("sora_video_price_per_request_hd=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
|
||||||
}
|
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("sora_storage_quota_bytes=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
|
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("claude_code_only=")
|
builder.WriteString("claude_code_only=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@ -49,16 +49,6 @@ const (
|
|||||||
FieldImagePrice2k = "image_price_2k"
|
FieldImagePrice2k = "image_price_2k"
|
||||||
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
||||||
FieldImagePrice4k = "image_price_4k"
|
FieldImagePrice4k = "image_price_4k"
|
||||||
// FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
|
|
||||||
FieldSoraImagePrice360 = "sora_image_price_360"
|
|
||||||
// FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
|
|
||||||
FieldSoraImagePrice540 = "sora_image_price_540"
|
|
||||||
// FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
|
|
||||||
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
|
|
||||||
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
|
|
||||||
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
|
|
||||||
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
|
|
||||||
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
|
|
||||||
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
||||||
FieldClaudeCodeOnly = "claude_code_only"
|
FieldClaudeCodeOnly = "claude_code_only"
|
||||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||||
@ -175,11 +165,6 @@ var Columns = []string{
|
|||||||
FieldImagePrice1k,
|
FieldImagePrice1k,
|
||||||
FieldImagePrice2k,
|
FieldImagePrice2k,
|
||||||
FieldImagePrice4k,
|
FieldImagePrice4k,
|
||||||
FieldSoraImagePrice360,
|
|
||||||
FieldSoraImagePrice540,
|
|
||||||
FieldSoraVideoPricePerRequest,
|
|
||||||
FieldSoraVideoPricePerRequestHd,
|
|
||||||
FieldSoraStorageQuotaBytes,
|
|
||||||
FieldClaudeCodeOnly,
|
FieldClaudeCodeOnly,
|
||||||
FieldFallbackGroupID,
|
FieldFallbackGroupID,
|
||||||
FieldFallbackGroupIDOnInvalidRequest,
|
FieldFallbackGroupIDOnInvalidRequest,
|
||||||
@ -247,8 +232,6 @@ var (
|
|||||||
SubscriptionTypeValidator func(string) error
|
SubscriptionTypeValidator func(string) error
|
||||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||||
DefaultDefaultValidityDays int
|
DefaultDefaultValidityDays int
|
||||||
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
|
|
||||||
DefaultSoraStorageQuotaBytes int64
|
|
||||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||||
DefaultClaudeCodeOnly bool
|
DefaultClaudeCodeOnly bool
|
||||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||||
@ -364,31 +347,6 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
|
|
||||||
func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
|
|
||||||
func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
|
|
||||||
func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
|
|
||||||
func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
|
|
||||||
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
||||||
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
||||||
|
|||||||
@ -140,31 +140,6 @@ func ImagePrice4k(v float64) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
|
|
||||||
func SoraImagePrice360(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
|
|
||||||
func SoraImagePrice540(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
|
|
||||||
func SoraVideoPricePerRequest(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
|
|
||||||
func SoraVideoPricePerRequestHd(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
|
|
||||||
func SoraStorageQuotaBytes(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
||||||
func ClaudeCodeOnly(v bool) predicate.Group {
|
func ClaudeCodeOnly(v bool) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||||
@ -1070,246 +1045,6 @@ func ImagePrice4kNotNil() predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360EQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360NEQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360In(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360GT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360GTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360LT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360LTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360IsNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
|
|
||||||
func SoraImagePrice360NotNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540EQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540NEQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540In(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540GT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540GTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540LT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540LTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540IsNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
|
|
||||||
func SoraImagePrice540NotNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestGT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestLT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestIsNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
|
|
||||||
func SoraVideoPricePerRequestNotNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdIsNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
|
|
||||||
func SoraVideoPricePerRequestHdNotNil() predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesGT(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesLT(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
||||||
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||||
|
|||||||
@ -258,76 +258,6 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
|
|
||||||
_c.mutation.SetSoraImagePrice360(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraImagePrice360(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
|
|
||||||
_c.mutation.SetSoraImagePrice540(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraImagePrice540(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
|
|
||||||
_c.mutation.SetSoraVideoPricePerRequest(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraVideoPricePerRequest(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
|
|
||||||
_c.mutation.SetSoraVideoPricePerRequestHd(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraVideoPricePerRequestHd(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
|
|
||||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraStorageQuotaBytes(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
||||||
_c.mutation.SetClaudeCodeOnly(v)
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
@ -645,10 +575,6 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultDefaultValidityDays
|
v := group.DefaultDefaultValidityDays
|
||||||
_c.mutation.SetDefaultValidityDays(v)
|
_c.mutation.SetDefaultValidityDays(v)
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
|
||||||
v := group.DefaultSoraStorageQuotaBytes
|
|
||||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
v := group.DefaultClaudeCodeOnly
|
v := group.DefaultClaudeCodeOnly
|
||||||
_c.mutation.SetClaudeCodeOnly(v)
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
@ -737,9 +663,6 @@ func (_c *GroupCreate) check() error {
|
|||||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
|
||||||
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||||
}
|
}
|
||||||
@ -867,26 +790,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||||
_node.ImagePrice4k = &value
|
_node.ImagePrice4k = &value
|
||||||
}
|
}
|
||||||
if value, ok := _c.mutation.SoraImagePrice360(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
|
||||||
_node.SoraImagePrice360 = &value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.SoraImagePrice540(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
|
||||||
_node.SoraImagePrice540 = &value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
|
||||||
_node.SoraVideoPricePerRequest = &value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
|
||||||
_node.SoraVideoPricePerRequestHd = &value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
_node.SoraStorageQuotaBytes = value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
||||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
_node.ClaudeCodeOnly = value
|
_node.ClaudeCodeOnly = value
|
||||||
@ -1379,120 +1282,6 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSoraImagePrice360, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSoraImagePrice360)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
|
|
||||||
u.Add(group.FieldSoraImagePrice360, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
|
|
||||||
u.SetNull(group.FieldSoraImagePrice360)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSoraImagePrice540, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSoraImagePrice540)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
|
|
||||||
u.Add(group.FieldSoraImagePrice540, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
|
|
||||||
u.SetNull(group.FieldSoraImagePrice540)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSoraVideoPricePerRequest, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSoraVideoPricePerRequest)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
|
|
||||||
u.Add(group.FieldSoraVideoPricePerRequest, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
|
|
||||||
u.SetNull(group.FieldSoraVideoPricePerRequest)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSoraVideoPricePerRequestHd, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
|
|
||||||
u.Add(group.FieldSoraVideoPricePerRequestHd, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
|
|
||||||
u.SetNull(group.FieldSoraVideoPricePerRequestHd)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSoraStorageQuotaBytes, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSoraStorageQuotaBytes)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
|
|
||||||
u.Add(group.FieldSoraStorageQuotaBytes, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
||||||
u.Set(group.FieldClaudeCodeOnly, v)
|
u.Set(group.FieldClaudeCodeOnly, v)
|
||||||
@ -2054,139 +1843,6 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraImagePrice360(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraImagePrice360(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraImagePrice360()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraImagePrice360()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraImagePrice540(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraImagePrice540(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraImagePrice540()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraImagePrice540()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraVideoPricePerRequest(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraVideoPricePerRequest(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraVideoPricePerRequest()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraVideoPricePerRequest()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraVideoPricePerRequestHd(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraVideoPricePerRequestHd(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraVideoPricePerRequestHd()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraVideoPricePerRequestHd()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraStorageQuotaBytes()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
@ -2944,139 +2600,6 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraImagePrice360(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraImagePrice360(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraImagePrice360()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
|
||||||
func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraImagePrice360()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraImagePrice540(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraImagePrice540(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraImagePrice540()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
|
||||||
func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraImagePrice540()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraVideoPricePerRequest(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraVideoPricePerRequest(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraVideoPricePerRequest()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
|
||||||
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraVideoPricePerRequest()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraVideoPricePerRequestHd(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraVideoPricePerRequestHd(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraVideoPricePerRequestHd()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.ClearSoraVideoPricePerRequestHd()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.AddSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSoraStorageQuotaBytes()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
|||||||
@ -355,135 +355,6 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.ResetSoraImagePrice360()
|
|
||||||
_u.mutation.SetSoraImagePrice360(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraImagePrice360(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
|
|
||||||
func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.AddSoraImagePrice360(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
|
||||||
func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
|
|
||||||
_u.mutation.ClearSoraImagePrice360()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.ResetSoraImagePrice540()
|
|
||||||
_u.mutation.SetSoraImagePrice540(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraImagePrice540(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
|
|
||||||
func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.AddSoraImagePrice540(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
|
||||||
func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
|
|
||||||
_u.mutation.ClearSoraImagePrice540()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.ResetSoraVideoPricePerRequest()
|
|
||||||
_u.mutation.SetSoraVideoPricePerRequest(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraVideoPricePerRequest(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
|
|
||||||
func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.AddSoraVideoPricePerRequest(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
|
||||||
func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
|
|
||||||
_u.mutation.ClearSoraVideoPricePerRequest()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.ResetSoraVideoPricePerRequestHd()
|
|
||||||
_u.mutation.SetSoraVideoPricePerRequestHd(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraVideoPricePerRequestHd(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
|
|
||||||
_u.mutation.AddSoraVideoPricePerRequestHd(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
|
|
||||||
_u.mutation.ClearSoraVideoPricePerRequestHd()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
|
|
||||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
|
||||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraStorageQuotaBytes(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
|
|
||||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
||||||
_u.mutation.SetClaudeCodeOnly(v)
|
_u.mutation.SetClaudeCodeOnly(v)
|
||||||
@ -1082,48 +953,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.ImagePrice4kCleared() {
|
if _u.mutation.ImagePrice4kCleared() {
|
||||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SoraImagePrice360(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraImagePrice360Cleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraImagePrice540(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraImagePrice540Cleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraVideoPricePerRequestCleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
@ -1817,135 +1646,6 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.ResetSoraImagePrice360()
|
|
||||||
_u.mutation.SetSoraImagePrice360(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraImagePrice360(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
|
|
||||||
func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.AddSoraImagePrice360(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
|
||||||
func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
|
|
||||||
_u.mutation.ClearSoraImagePrice360()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.ResetSoraImagePrice540()
|
|
||||||
_u.mutation.SetSoraImagePrice540(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraImagePrice540(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
|
|
||||||
func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.AddSoraImagePrice540(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
|
||||||
func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
|
|
||||||
_u.mutation.ClearSoraImagePrice540()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.ResetSoraVideoPricePerRequest()
|
|
||||||
_u.mutation.SetSoraVideoPricePerRequest(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraVideoPricePerRequest(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
|
|
||||||
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.AddSoraVideoPricePerRequest(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
|
||||||
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
|
|
||||||
_u.mutation.ClearSoraVideoPricePerRequest()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.ResetSoraVideoPricePerRequestHd()
|
|
||||||
_u.mutation.SetSoraVideoPricePerRequestHd(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraVideoPricePerRequestHd(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
|
|
||||||
_u.mutation.AddSoraVideoPricePerRequestHd(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
|
|
||||||
_u.mutation.ClearSoraVideoPricePerRequestHd()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
|
|
||||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
|
||||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraStorageQuotaBytes(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
|
|
||||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
||||||
_u.mutation.SetClaudeCodeOnly(v)
|
_u.mutation.SetClaudeCodeOnly(v)
|
||||||
@ -2574,48 +2274,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if _u.mutation.ImagePrice4kCleared() {
|
if _u.mutation.ImagePrice4kCleared() {
|
||||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SoraImagePrice360(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraImagePrice360Cleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraImagePrice540(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraImagePrice540Cleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraVideoPricePerRequestCleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
|
||||||
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -395,11 +395,6 @@ var (
|
|||||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
|
||||||
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
|
||||||
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
|
||||||
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
|
||||||
{Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
|
|
||||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
||||||
@ -447,7 +442,7 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "group_sort_order",
|
Name: "group_sort_order",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{GroupsColumns[30]},
|
Columns: []*schema.Column{GroupsColumns[25]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -770,7 +765,6 @@ var (
|
|||||||
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
||||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||||
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
|
|
||||||
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
||||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "api_key_id", Type: field.TypeInt64},
|
{Name: "api_key_id", Type: field.TypeInt64},
|
||||||
@ -787,31 +781,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@ -820,32 +814,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@ -865,17 +859,17 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]},
|
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]},
|
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id_created_at",
|
Name: "usagelog_group_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]},
|
Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -896,8 +890,6 @@ var (
|
|||||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||||
{Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
|
|
||||||
{Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
|
|
||||||
}
|
}
|
||||||
// UsersTable holds the schema information for the "users" table.
|
// UsersTable holds the schema information for the "users" table.
|
||||||
UsersTable = &schema.Table{
|
UsersTable = &schema.Table{
|
||||||
|
|||||||
@ -8230,16 +8230,6 @@ type GroupMutation struct {
|
|||||||
addimage_price_2k *float64
|
addimage_price_2k *float64
|
||||||
image_price_4k *float64
|
image_price_4k *float64
|
||||||
addimage_price_4k *float64
|
addimage_price_4k *float64
|
||||||
sora_image_price_360 *float64
|
|
||||||
addsora_image_price_360 *float64
|
|
||||||
sora_image_price_540 *float64
|
|
||||||
addsora_image_price_540 *float64
|
|
||||||
sora_video_price_per_request *float64
|
|
||||||
addsora_video_price_per_request *float64
|
|
||||||
sora_video_price_per_request_hd *float64
|
|
||||||
addsora_video_price_per_request_hd *float64
|
|
||||||
sora_storage_quota_bytes *int64
|
|
||||||
addsora_storage_quota_bytes *int64
|
|
||||||
claude_code_only *bool
|
claude_code_only *bool
|
||||||
fallback_group_id *int64
|
fallback_group_id *int64
|
||||||
addfallback_group_id *int64
|
addfallback_group_id *int64
|
||||||
@ -9260,342 +9250,6 @@ func (m *GroupMutation) ResetImagePrice4k() {
|
|||||||
delete(m.clearedFields, group.FieldImagePrice4k)
|
delete(m.clearedFields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
|
|
||||||
func (m *GroupMutation) SetSoraImagePrice360(f float64) {
|
|
||||||
m.sora_image_price_360 = &f
|
|
||||||
m.addsora_image_price_360 = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation.
|
|
||||||
func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) {
|
|
||||||
v := m.sora_image_price_360
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraImagePrice360, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice360 adds f to the "sora_image_price_360" field.
|
|
||||||
func (m *GroupMutation) AddSoraImagePrice360(f float64) {
|
|
||||||
if m.addsora_image_price_360 != nil {
|
|
||||||
*m.addsora_image_price_360 += f
|
|
||||||
} else {
|
|
||||||
m.addsora_image_price_360 = &f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation.
|
|
||||||
func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) {
|
|
||||||
v := m.addsora_image_price_360
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
|
|
||||||
func (m *GroupMutation) ClearSoraImagePrice360() {
|
|
||||||
m.sora_image_price_360 = nil
|
|
||||||
m.addsora_image_price_360 = nil
|
|
||||||
m.clearedFields[group.FieldSoraImagePrice360] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation.
|
|
||||||
func (m *GroupMutation) SoraImagePrice360Cleared() bool {
|
|
||||||
_, ok := m.clearedFields[group.FieldSoraImagePrice360]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field.
|
|
||||||
func (m *GroupMutation) ResetSoraImagePrice360() {
|
|
||||||
m.sora_image_price_360 = nil
|
|
||||||
m.addsora_image_price_360 = nil
|
|
||||||
delete(m.clearedFields, group.FieldSoraImagePrice360)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
|
|
||||||
func (m *GroupMutation) SetSoraImagePrice540(f float64) {
|
|
||||||
m.sora_image_price_540 = &f
|
|
||||||
m.addsora_image_price_540 = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation.
|
|
||||||
func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) {
|
|
||||||
v := m.sora_image_price_540
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraImagePrice540, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraImagePrice540 adds f to the "sora_image_price_540" field.
|
|
||||||
func (m *GroupMutation) AddSoraImagePrice540(f float64) {
|
|
||||||
if m.addsora_image_price_540 != nil {
|
|
||||||
*m.addsora_image_price_540 += f
|
|
||||||
} else {
|
|
||||||
m.addsora_image_price_540 = &f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation.
|
|
||||||
func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) {
|
|
||||||
v := m.addsora_image_price_540
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
|
|
||||||
func (m *GroupMutation) ClearSoraImagePrice540() {
|
|
||||||
m.sora_image_price_540 = nil
|
|
||||||
m.addsora_image_price_540 = nil
|
|
||||||
m.clearedFields[group.FieldSoraImagePrice540] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation.
|
|
||||||
func (m *GroupMutation) SoraImagePrice540Cleared() bool {
|
|
||||||
_, ok := m.clearedFields[group.FieldSoraImagePrice540]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field.
|
|
||||||
func (m *GroupMutation) ResetSoraImagePrice540() {
|
|
||||||
m.sora_image_price_540 = nil
|
|
||||||
m.addsora_image_price_540 = nil
|
|
||||||
delete(m.clearedFields, group.FieldSoraImagePrice540)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
|
|
||||||
func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) {
|
|
||||||
m.sora_video_price_per_request = &f
|
|
||||||
m.addsora_video_price_per_request = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation.
|
|
||||||
func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) {
|
|
||||||
v := m.sora_video_price_per_request
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraVideoPricePerRequest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field.
|
|
||||||
func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) {
|
|
||||||
if m.addsora_video_price_per_request != nil {
|
|
||||||
*m.addsora_video_price_per_request += f
|
|
||||||
} else {
|
|
||||||
m.addsora_video_price_per_request = &f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation.
|
|
||||||
func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) {
|
|
||||||
v := m.addsora_video_price_per_request
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
|
|
||||||
func (m *GroupMutation) ClearSoraVideoPricePerRequest() {
|
|
||||||
m.sora_video_price_per_request = nil
|
|
||||||
m.addsora_video_price_per_request = nil
|
|
||||||
m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation.
|
|
||||||
func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool {
|
|
||||||
_, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field.
|
|
||||||
func (m *GroupMutation) ResetSoraVideoPricePerRequest() {
|
|
||||||
m.sora_video_price_per_request = nil
|
|
||||||
m.addsora_video_price_per_request = nil
|
|
||||||
delete(m.clearedFields, group.FieldSoraVideoPricePerRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
|
|
||||||
func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) {
|
|
||||||
m.sora_video_price_per_request_hd = &f
|
|
||||||
m.addsora_video_price_per_request_hd = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation.
|
|
||||||
func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) {
|
|
||||||
v := m.sora_video_price_per_request_hd
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraVideoPricePerRequestHd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) {
|
|
||||||
if m.addsora_video_price_per_request_hd != nil {
|
|
||||||
*m.addsora_video_price_per_request_hd += f
|
|
||||||
} else {
|
|
||||||
m.addsora_video_price_per_request_hd = &f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation.
|
|
||||||
func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) {
|
|
||||||
v := m.addsora_video_price_per_request_hd
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
|
|
||||||
func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() {
|
|
||||||
m.sora_video_price_per_request_hd = nil
|
|
||||||
m.addsora_video_price_per_request_hd = nil
|
|
||||||
m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation.
|
|
||||||
func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool {
|
|
||||||
_, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field.
|
|
||||||
func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() {
|
|
||||||
m.sora_video_price_per_request_hd = nil
|
|
||||||
m.addsora_video_price_per_request_hd = nil
|
|
||||||
delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) {
|
|
||||||
m.sora_storage_quota_bytes = &i
|
|
||||||
m.addsora_storage_quota_bytes = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
|
|
||||||
func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
|
|
||||||
v := m.sora_storage_quota_bytes
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraStorageQuotaBytes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
|
|
||||||
func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) {
|
|
||||||
if m.addsora_storage_quota_bytes != nil {
|
|
||||||
*m.addsora_storage_quota_bytes += i
|
|
||||||
} else {
|
|
||||||
m.addsora_storage_quota_bytes = &i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
|
|
||||||
func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
|
|
||||||
v := m.addsora_storage_quota_bytes
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
|
|
||||||
func (m *GroupMutation) ResetSoraStorageQuotaBytes() {
|
|
||||||
m.sora_storage_quota_bytes = nil
|
|
||||||
m.addsora_storage_quota_bytes = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
|
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
|
||||||
m.claude_code_only = &b
|
m.claude_code_only = &b
|
||||||
@ -10502,7 +10156,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 34)
|
fields := make([]string, 0, 29)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@ -10554,21 +10208,6 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.image_price_4k != nil {
|
if m.image_price_4k != nil {
|
||||||
fields = append(fields, group.FieldImagePrice4k)
|
fields = append(fields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
if m.sora_image_price_360 != nil {
|
|
||||||
fields = append(fields, group.FieldSoraImagePrice360)
|
|
||||||
}
|
|
||||||
if m.sora_image_price_540 != nil {
|
|
||||||
fields = append(fields, group.FieldSoraImagePrice540)
|
|
||||||
}
|
|
||||||
if m.sora_video_price_per_request != nil {
|
|
||||||
fields = append(fields, group.FieldSoraVideoPricePerRequest)
|
|
||||||
}
|
|
||||||
if m.sora_video_price_per_request_hd != nil {
|
|
||||||
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
|
|
||||||
}
|
|
||||||
if m.sora_storage_quota_bytes != nil {
|
|
||||||
fields = append(fields, group.FieldSoraStorageQuotaBytes)
|
|
||||||
}
|
|
||||||
if m.claude_code_only != nil {
|
if m.claude_code_only != nil {
|
||||||
fields = append(fields, group.FieldClaudeCodeOnly)
|
fields = append(fields, group.FieldClaudeCodeOnly)
|
||||||
}
|
}
|
||||||
@ -10647,16 +10286,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ImagePrice2k()
|
return m.ImagePrice2k()
|
||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
return m.ImagePrice4k()
|
return m.ImagePrice4k()
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
return m.SoraImagePrice360()
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
return m.SoraImagePrice540()
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
return m.SoraVideoPricePerRequest()
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
return m.SoraVideoPricePerRequestHd()
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
return m.SoraStorageQuotaBytes()
|
|
||||||
case group.FieldClaudeCodeOnly:
|
case group.FieldClaudeCodeOnly:
|
||||||
return m.ClaudeCodeOnly()
|
return m.ClaudeCodeOnly()
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
@ -10724,16 +10353,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldImagePrice2k(ctx)
|
return m.OldImagePrice2k(ctx)
|
||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
return m.OldImagePrice4k(ctx)
|
return m.OldImagePrice4k(ctx)
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
return m.OldSoraImagePrice360(ctx)
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
return m.OldSoraImagePrice540(ctx)
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
return m.OldSoraVideoPricePerRequest(ctx)
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
return m.OldSoraVideoPricePerRequestHd(ctx)
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
return m.OldSoraStorageQuotaBytes(ctx)
|
|
||||||
case group.FieldClaudeCodeOnly:
|
case group.FieldClaudeCodeOnly:
|
||||||
return m.OldClaudeCodeOnly(ctx)
|
return m.OldClaudeCodeOnly(ctx)
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
@ -10886,41 +10505,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetImagePrice4k(v)
|
m.SetImagePrice4k(v)
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraImagePrice360(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraImagePrice540(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraVideoPricePerRequest(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraVideoPricePerRequestHd(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
v, ok := value.(int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraStorageQuotaBytes(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldClaudeCodeOnly:
|
case group.FieldClaudeCodeOnly:
|
||||||
v, ok := value.(bool)
|
v, ok := value.(bool)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -11037,21 +10621,6 @@ func (m *GroupMutation) AddedFields() []string {
|
|||||||
if m.addimage_price_4k != nil {
|
if m.addimage_price_4k != nil {
|
||||||
fields = append(fields, group.FieldImagePrice4k)
|
fields = append(fields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
if m.addsora_image_price_360 != nil {
|
|
||||||
fields = append(fields, group.FieldSoraImagePrice360)
|
|
||||||
}
|
|
||||||
if m.addsora_image_price_540 != nil {
|
|
||||||
fields = append(fields, group.FieldSoraImagePrice540)
|
|
||||||
}
|
|
||||||
if m.addsora_video_price_per_request != nil {
|
|
||||||
fields = append(fields, group.FieldSoraVideoPricePerRequest)
|
|
||||||
}
|
|
||||||
if m.addsora_video_price_per_request_hd != nil {
|
|
||||||
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
|
|
||||||
}
|
|
||||||
if m.addsora_storage_quota_bytes != nil {
|
|
||||||
fields = append(fields, group.FieldSoraStorageQuotaBytes)
|
|
||||||
}
|
|
||||||
if m.addfallback_group_id != nil {
|
if m.addfallback_group_id != nil {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
@ -11085,16 +10654,6 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedImagePrice2k()
|
return m.AddedImagePrice2k()
|
||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
return m.AddedImagePrice4k()
|
return m.AddedImagePrice4k()
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
return m.AddedSoraImagePrice360()
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
return m.AddedSoraImagePrice540()
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
return m.AddedSoraVideoPricePerRequest()
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
return m.AddedSoraVideoPricePerRequestHd()
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
return m.AddedSoraStorageQuotaBytes()
|
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.AddedFallbackGroupID()
|
return m.AddedFallbackGroupID()
|
||||||
case group.FieldFallbackGroupIDOnInvalidRequest:
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
@ -11166,41 +10725,6 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddImagePrice4k(v)
|
m.AddImagePrice4k(v)
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraImagePrice360(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraImagePrice540(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraVideoPricePerRequest(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
v, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraVideoPricePerRequestHd(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
v, ok := value.(int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraStorageQuotaBytes(v)
|
|
||||||
return nil
|
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
v, ok := value.(int64)
|
v, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -11254,18 +10778,6 @@ func (m *GroupMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(group.FieldImagePrice4k) {
|
if m.FieldCleared(group.FieldImagePrice4k) {
|
||||||
fields = append(fields, group.FieldImagePrice4k)
|
fields = append(fields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
if m.FieldCleared(group.FieldSoraImagePrice360) {
|
|
||||||
fields = append(fields, group.FieldSoraImagePrice360)
|
|
||||||
}
|
|
||||||
if m.FieldCleared(group.FieldSoraImagePrice540) {
|
|
||||||
fields = append(fields, group.FieldSoraImagePrice540)
|
|
||||||
}
|
|
||||||
if m.FieldCleared(group.FieldSoraVideoPricePerRequest) {
|
|
||||||
fields = append(fields, group.FieldSoraVideoPricePerRequest)
|
|
||||||
}
|
|
||||||
if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) {
|
|
||||||
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
|
|
||||||
}
|
|
||||||
if m.FieldCleared(group.FieldFallbackGroupID) {
|
if m.FieldCleared(group.FieldFallbackGroupID) {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
@ -11313,18 +10825,6 @@ func (m *GroupMutation) ClearField(name string) error {
|
|||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
m.ClearImagePrice4k()
|
m.ClearImagePrice4k()
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
m.ClearSoraImagePrice360()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
m.ClearSoraImagePrice540()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
m.ClearSoraVideoPricePerRequest()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
m.ClearSoraVideoPricePerRequestHd()
|
|
||||||
return nil
|
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
m.ClearFallbackGroupID()
|
m.ClearFallbackGroupID()
|
||||||
return nil
|
return nil
|
||||||
@ -11393,21 +10893,6 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
m.ResetImagePrice4k()
|
m.ResetImagePrice4k()
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSoraImagePrice360:
|
|
||||||
m.ResetSoraImagePrice360()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraImagePrice540:
|
|
||||||
m.ResetSoraImagePrice540()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequest:
|
|
||||||
m.ResetSoraVideoPricePerRequest()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraVideoPricePerRequestHd:
|
|
||||||
m.ResetSoraVideoPricePerRequestHd()
|
|
||||||
return nil
|
|
||||||
case group.FieldSoraStorageQuotaBytes:
|
|
||||||
m.ResetSoraStorageQuotaBytes()
|
|
||||||
return nil
|
|
||||||
case group.FieldClaudeCodeOnly:
|
case group.FieldClaudeCodeOnly:
|
||||||
m.ResetClaudeCodeOnly()
|
m.ResetClaudeCodeOnly()
|
||||||
return nil
|
return nil
|
||||||
@ -19770,7 +19255,6 @@ type UsageLogMutation struct {
|
|||||||
image_count *int
|
image_count *int
|
||||||
addimage_count *int
|
addimage_count *int
|
||||||
image_size *string
|
image_size *string
|
||||||
media_type *string
|
|
||||||
cache_ttl_overridden *bool
|
cache_ttl_overridden *bool
|
||||||
created_at *time.Time
|
created_at *time.Time
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
@ -21713,55 +21197,6 @@ func (m *UsageLogMutation) ResetImageSize() {
|
|||||||
delete(m.clearedFields, usagelog.FieldImageSize)
|
delete(m.clearedFields, usagelog.FieldImageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (m *UsageLogMutation) SetMediaType(s string) {
|
|
||||||
m.media_type = &s
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaType returns the value of the "media_type" field in the mutation.
|
|
||||||
func (m *UsageLogMutation) MediaType() (r string, exists bool) {
|
|
||||||
v := m.media_type
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldMediaType returns the old "media_type" field's value of the UsageLog entity.
|
|
||||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldMediaType is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldMediaType requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldMediaType: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.MediaType, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearMediaType clears the value of the "media_type" field.
|
|
||||||
func (m *UsageLogMutation) ClearMediaType() {
|
|
||||||
m.media_type = nil
|
|
||||||
m.clearedFields[usagelog.FieldMediaType] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeCleared returns if the "media_type" field was cleared in this mutation.
|
|
||||||
func (m *UsageLogMutation) MediaTypeCleared() bool {
|
|
||||||
_, ok := m.clearedFields[usagelog.FieldMediaType]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetMediaType resets all changes to the "media_type" field.
|
|
||||||
func (m *UsageLogMutation) ResetMediaType() {
|
|
||||||
m.media_type = nil
|
|
||||||
delete(m.clearedFields, usagelog.FieldMediaType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
|
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
|
||||||
m.cache_ttl_overridden = &b
|
m.cache_ttl_overridden = &b
|
||||||
@ -22003,7 +21438,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 38)
|
fields := make([]string, 0, 37)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@ -22109,9 +21544,6 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.image_size != nil {
|
if m.image_size != nil {
|
||||||
fields = append(fields, usagelog.FieldImageSize)
|
fields = append(fields, usagelog.FieldImageSize)
|
||||||
}
|
}
|
||||||
if m.media_type != nil {
|
|
||||||
fields = append(fields, usagelog.FieldMediaType)
|
|
||||||
}
|
|
||||||
if m.cache_ttl_overridden != nil {
|
if m.cache_ttl_overridden != nil {
|
||||||
fields = append(fields, usagelog.FieldCacheTTLOverridden)
|
fields = append(fields, usagelog.FieldCacheTTLOverridden)
|
||||||
}
|
}
|
||||||
@ -22196,8 +21628,6 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ImageCount()
|
return m.ImageCount()
|
||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
return m.ImageSize()
|
return m.ImageSize()
|
||||||
case usagelog.FieldMediaType:
|
|
||||||
return m.MediaType()
|
|
||||||
case usagelog.FieldCacheTTLOverridden:
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
return m.CacheTTLOverridden()
|
return m.CacheTTLOverridden()
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
@ -22281,8 +21711,6 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldImageCount(ctx)
|
return m.OldImageCount(ctx)
|
||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
return m.OldImageSize(ctx)
|
return m.OldImageSize(ctx)
|
||||||
case usagelog.FieldMediaType:
|
|
||||||
return m.OldMediaType(ctx)
|
|
||||||
case usagelog.FieldCacheTTLOverridden:
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
return m.OldCacheTTLOverridden(ctx)
|
return m.OldCacheTTLOverridden(ctx)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
@ -22541,13 +21969,6 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetImageSize(v)
|
m.SetImageSize(v)
|
||||||
return nil
|
return nil
|
||||||
case usagelog.FieldMediaType:
|
|
||||||
v, ok := value.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetMediaType(v)
|
|
||||||
return nil
|
|
||||||
case usagelog.FieldCacheTTLOverridden:
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
v, ok := value.(bool)
|
v, ok := value.(bool)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -22865,9 +22286,6 @@ func (m *UsageLogMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(usagelog.FieldImageSize) {
|
if m.FieldCleared(usagelog.FieldImageSize) {
|
||||||
fields = append(fields, usagelog.FieldImageSize)
|
fields = append(fields, usagelog.FieldImageSize)
|
||||||
}
|
}
|
||||||
if m.FieldCleared(usagelog.FieldMediaType) {
|
|
||||||
fields = append(fields, usagelog.FieldMediaType)
|
|
||||||
}
|
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22924,9 +22342,6 @@ func (m *UsageLogMutation) ClearField(name string) error {
|
|||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
m.ClearImageSize()
|
m.ClearImageSize()
|
||||||
return nil
|
return nil
|
||||||
case usagelog.FieldMediaType:
|
|
||||||
m.ClearMediaType()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown UsageLog nullable field %s", name)
|
return fmt.Errorf("unknown UsageLog nullable field %s", name)
|
||||||
}
|
}
|
||||||
@ -23040,9 +22455,6 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
m.ResetImageSize()
|
m.ResetImageSize()
|
||||||
return nil
|
return nil
|
||||||
case usagelog.FieldMediaType:
|
|
||||||
m.ResetMediaType()
|
|
||||||
return nil
|
|
||||||
case usagelog.FieldCacheTTLOverridden:
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
m.ResetCacheTTLOverridden()
|
m.ResetCacheTTLOverridden()
|
||||||
return nil
|
return nil
|
||||||
@ -23221,10 +22633,6 @@ type UserMutation struct {
|
|||||||
totp_secret_encrypted *string
|
totp_secret_encrypted *string
|
||||||
totp_enabled *bool
|
totp_enabled *bool
|
||||||
totp_enabled_at *time.Time
|
totp_enabled_at *time.Time
|
||||||
sora_storage_quota_bytes *int64
|
|
||||||
addsora_storage_quota_bytes *int64
|
|
||||||
sora_storage_used_bytes *int64
|
|
||||||
addsora_storage_used_bytes *int64
|
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@ -23939,118 +23347,6 @@ func (m *UserMutation) ResetTotpEnabledAt() {
|
|||||||
delete(m.clearedFields, user.FieldTotpEnabledAt)
|
delete(m.clearedFields, user.FieldTotpEnabledAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) {
|
|
||||||
m.sora_storage_quota_bytes = &i
|
|
||||||
m.addsora_storage_quota_bytes = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
|
|
||||||
func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
|
|
||||||
v := m.sora_storage_quota_bytes
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity.
|
|
||||||
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraStorageQuotaBytes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
|
|
||||||
func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) {
|
|
||||||
if m.addsora_storage_quota_bytes != nil {
|
|
||||||
*m.addsora_storage_quota_bytes += i
|
|
||||||
} else {
|
|
||||||
m.addsora_storage_quota_bytes = &i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
|
|
||||||
func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
|
|
||||||
v := m.addsora_storage_quota_bytes
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
|
|
||||||
func (m *UserMutation) ResetSoraStorageQuotaBytes() {
|
|
||||||
m.sora_storage_quota_bytes = nil
|
|
||||||
m.addsora_storage_quota_bytes = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (m *UserMutation) SetSoraStorageUsedBytes(i int64) {
|
|
||||||
m.sora_storage_used_bytes = &i
|
|
||||||
m.addsora_storage_used_bytes = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation.
|
|
||||||
func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) {
|
|
||||||
v := m.sora_storage_used_bytes
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity.
|
|
||||||
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SoraStorageUsedBytes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field.
|
|
||||||
func (m *UserMutation) AddSoraStorageUsedBytes(i int64) {
|
|
||||||
if m.addsora_storage_used_bytes != nil {
|
|
||||||
*m.addsora_storage_used_bytes += i
|
|
||||||
} else {
|
|
||||||
m.addsora_storage_used_bytes = &i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation.
|
|
||||||
func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) {
|
|
||||||
v := m.addsora_storage_used_bytes
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field.
|
|
||||||
func (m *UserMutation) ResetSoraStorageUsedBytes() {
|
|
||||||
m.sora_storage_used_bytes = nil
|
|
||||||
m.addsora_storage_used_bytes = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@ -24571,7 +23867,7 @@ func (m *UserMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UserMutation) Fields() []string {
|
func (m *UserMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 16)
|
fields := make([]string, 0, 14)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, user.FieldCreatedAt)
|
fields = append(fields, user.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@ -24614,12 +23910,6 @@ func (m *UserMutation) Fields() []string {
|
|||||||
if m.totp_enabled_at != nil {
|
if m.totp_enabled_at != nil {
|
||||||
fields = append(fields, user.FieldTotpEnabledAt)
|
fields = append(fields, user.FieldTotpEnabledAt)
|
||||||
}
|
}
|
||||||
if m.sora_storage_quota_bytes != nil {
|
|
||||||
fields = append(fields, user.FieldSoraStorageQuotaBytes)
|
|
||||||
}
|
|
||||||
if m.sora_storage_used_bytes != nil {
|
|
||||||
fields = append(fields, user.FieldSoraStorageUsedBytes)
|
|
||||||
}
|
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24656,10 +23946,6 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.TotpEnabled()
|
return m.TotpEnabled()
|
||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
return m.TotpEnabledAt()
|
return m.TotpEnabledAt()
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
return m.SoraStorageQuotaBytes()
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
return m.SoraStorageUsedBytes()
|
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -24697,10 +23983,6 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
|
|||||||
return m.OldTotpEnabled(ctx)
|
return m.OldTotpEnabled(ctx)
|
||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
return m.OldTotpEnabledAt(ctx)
|
return m.OldTotpEnabledAt(ctx)
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
return m.OldSoraStorageQuotaBytes(ctx)
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
return m.OldSoraStorageUsedBytes(ctx)
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown User field %s", name)
|
return nil, fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@ -24808,20 +24090,6 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetTotpEnabledAt(v)
|
m.SetTotpEnabledAt(v)
|
||||||
return nil
|
return nil
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
v, ok := value.(int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraStorageQuotaBytes(v)
|
|
||||||
return nil
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
v, ok := value.(int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSoraStorageUsedBytes(v)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@ -24836,12 +24104,6 @@ func (m *UserMutation) AddedFields() []string {
|
|||||||
if m.addconcurrency != nil {
|
if m.addconcurrency != nil {
|
||||||
fields = append(fields, user.FieldConcurrency)
|
fields = append(fields, user.FieldConcurrency)
|
||||||
}
|
}
|
||||||
if m.addsora_storage_quota_bytes != nil {
|
|
||||||
fields = append(fields, user.FieldSoraStorageQuotaBytes)
|
|
||||||
}
|
|
||||||
if m.addsora_storage_used_bytes != nil {
|
|
||||||
fields = append(fields, user.FieldSoraStorageUsedBytes)
|
|
||||||
}
|
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24854,10 +24116,6 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedBalance()
|
return m.AddedBalance()
|
||||||
case user.FieldConcurrency:
|
case user.FieldConcurrency:
|
||||||
return m.AddedConcurrency()
|
return m.AddedConcurrency()
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
return m.AddedSoraStorageQuotaBytes()
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
return m.AddedSoraStorageUsedBytes()
|
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -24881,20 +24139,6 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddConcurrency(v)
|
m.AddConcurrency(v)
|
||||||
return nil
|
return nil
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
v, ok := value.(int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraStorageQuotaBytes(v)
|
|
||||||
return nil
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
v, ok := value.(int64)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.AddSoraStorageUsedBytes(v)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User numeric field %s", name)
|
return fmt.Errorf("unknown User numeric field %s", name)
|
||||||
}
|
}
|
||||||
@ -24985,12 +24229,6 @@ func (m *UserMutation) ResetField(name string) error {
|
|||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
m.ResetTotpEnabledAt()
|
m.ResetTotpEnabledAt()
|
||||||
return nil
|
return nil
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
m.ResetSoraStorageQuotaBytes()
|
|
||||||
return nil
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
m.ResetSoraStorageUsedBytes()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -430,44 +430,40 @@ func init() {
|
|||||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||||
// groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
|
|
||||||
groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor()
|
|
||||||
// group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
|
|
||||||
group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
|
|
||||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||||
groupDescClaudeCodeOnly := groupFields[19].Descriptor()
|
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||||
groupDescModelRoutingEnabled := groupFields[23].Descriptor()
|
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
|
||||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||||
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
||||||
groupDescMcpXMLInject := groupFields[24].Descriptor()
|
groupDescMcpXMLInject := groupFields[19].Descriptor()
|
||||||
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
||||||
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
||||||
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
||||||
groupDescSupportedModelScopes := groupFields[25].Descriptor()
|
groupDescSupportedModelScopes := groupFields[20].Descriptor()
|
||||||
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||||
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||||
// groupDescSortOrder is the schema descriptor for sort_order field.
|
// groupDescSortOrder is the schema descriptor for sort_order field.
|
||||||
groupDescSortOrder := groupFields[26].Descriptor()
|
groupDescSortOrder := groupFields[21].Descriptor()
|
||||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||||
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
|
||||||
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
|
groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
|
||||||
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
|
||||||
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
|
||||||
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
|
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
|
||||||
groupDescRequireOauthOnly := groupFields[28].Descriptor()
|
groupDescRequireOauthOnly := groupFields[23].Descriptor()
|
||||||
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
|
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
|
||||||
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
|
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
|
||||||
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
|
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
|
||||||
groupDescRequirePrivacySet := groupFields[29].Descriptor()
|
groupDescRequirePrivacySet := groupFields[24].Descriptor()
|
||||||
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
|
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
|
||||||
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
|
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
|
||||||
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
|
||||||
groupDescDefaultMappedModel := groupFields[30].Descriptor()
|
groupDescDefaultMappedModel := groupFields[25].Descriptor()
|
||||||
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
|
||||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
@ -963,16 +959,12 @@ func init() {
|
|||||||
usagelogDescImageSize := usagelogFields[34].Descriptor()
|
usagelogDescImageSize := usagelogFields[34].Descriptor()
|
||||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
|
||||||
usagelogDescMediaType := usagelogFields[35].Descriptor()
|
|
||||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
|
||||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
|
||||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||||
usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor()
|
usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
|
||||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[37].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[36].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
@ -1064,14 +1056,6 @@ func init() {
|
|||||||
userDescTotpEnabled := userFields[9].Descriptor()
|
userDescTotpEnabled := userFields[9].Descriptor()
|
||||||
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||||
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||||
// userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
|
|
||||||
userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
|
|
||||||
// user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
|
|
||||||
user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
|
|
||||||
// userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
|
|
||||||
userDescSoraStorageUsedBytes := userFields[12].Descriptor()
|
|
||||||
// user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
|
|
||||||
user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
|
|
||||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||||
_ = userallowedgroupFields
|
_ = userallowedgroupFields
|
||||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
|||||||
@ -87,28 +87,6 @@ func (Group) Fields() []ent.Field {
|
|||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||||
|
|
||||||
// Sora 按次计费配置(阶段 1)
|
|
||||||
field.Float("sora_image_price_360").
|
|
||||||
Optional().
|
|
||||||
Nillable().
|
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
|
||||||
field.Float("sora_image_price_540").
|
|
||||||
Optional().
|
|
||||||
Nillable().
|
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
|
||||||
field.Float("sora_video_price_per_request").
|
|
||||||
Optional().
|
|
||||||
Nillable().
|
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
|
||||||
field.Float("sora_video_price_per_request_hd").
|
|
||||||
Optional().
|
|
||||||
Nillable().
|
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
|
||||||
|
|
||||||
// Sora 存储配额
|
|
||||||
field.Int64("sora_storage_quota_bytes").
|
|
||||||
Default(0),
|
|
||||||
|
|
||||||
// Claude Code 客户端限制 (added by migration 029)
|
// Claude Code 客户端限制 (added by migration 029)
|
||||||
field.Bool("claude_code_only").
|
field.Bool("claude_code_only").
|
||||||
Default(false).
|
Default(false).
|
||||||
|
|||||||
@ -134,12 +134,6 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
MaxLen(10).
|
MaxLen(10).
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
// 媒体类型字段(sora 使用)
|
|
||||||
field.String("media_type").
|
|
||||||
MaxLen(16).
|
|
||||||
Optional().
|
|
||||||
Nillable(),
|
|
||||||
|
|
||||||
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||||
field.Bool("cache_ttl_overridden").
|
field.Bool("cache_ttl_overridden").
|
||||||
Default(false),
|
Default(false),
|
||||||
|
|||||||
@ -72,12 +72,6 @@ func (User) Fields() []ent.Field {
|
|||||||
field.Time("totp_enabled_at").
|
field.Time("totp_enabled_at").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|
||||||
// Sora 存储配额
|
|
||||||
field.Int64("sora_storage_quota_bytes").
|
|
||||||
Default(0),
|
|
||||||
field.Int64("sora_storage_used_bytes").
|
|
||||||
Default(0),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -92,8 +92,6 @@ type UsageLog struct {
|
|||||||
ImageCount int `json:"image_count,omitempty"`
|
ImageCount int `json:"image_count,omitempty"`
|
||||||
// ImageSize holds the value of the "image_size" field.
|
// ImageSize holds the value of the "image_size" field.
|
||||||
ImageSize *string `json:"image_size,omitempty"`
|
ImageSize *string `json:"image_size,omitempty"`
|
||||||
// MediaType holds the value of the "media_type" field.
|
|
||||||
MediaType *string `json:"media_type,omitempty"`
|
|
||||||
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
||||||
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
||||||
// CreatedAt holds the value of the "created_at" field.
|
// CreatedAt holds the value of the "created_at" field.
|
||||||
@ -187,7 +185,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@ -436,13 +434,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
_m.ImageSize = new(string)
|
_m.ImageSize = new(string)
|
||||||
*_m.ImageSize = value.String
|
*_m.ImageSize = value.String
|
||||||
}
|
}
|
||||||
case usagelog.FieldMediaType:
|
|
||||||
if value, ok := values[i].(*sql.NullString); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field media_type", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.MediaType = new(string)
|
|
||||||
*_m.MediaType = value.String
|
|
||||||
}
|
|
||||||
case usagelog.FieldCacheTTLOverridden:
|
case usagelog.FieldCacheTTLOverridden:
|
||||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
||||||
@ -649,11 +640,6 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString(*v)
|
builder.WriteString(*v)
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
if v := _m.MediaType; v != nil {
|
|
||||||
builder.WriteString("media_type=")
|
|
||||||
builder.WriteString(*v)
|
|
||||||
}
|
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("cache_ttl_overridden=")
|
builder.WriteString("cache_ttl_overridden=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@ -84,8 +84,6 @@ const (
|
|||||||
FieldImageCount = "image_count"
|
FieldImageCount = "image_count"
|
||||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||||
FieldImageSize = "image_size"
|
FieldImageSize = "image_size"
|
||||||
// FieldMediaType holds the string denoting the media_type field in the database.
|
|
||||||
FieldMediaType = "media_type"
|
|
||||||
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
||||||
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
||||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||||
@ -177,7 +175,6 @@ var Columns = []string{
|
|||||||
FieldIPAddress,
|
FieldIPAddress,
|
||||||
FieldImageCount,
|
FieldImageCount,
|
||||||
FieldImageSize,
|
FieldImageSize,
|
||||||
FieldMediaType,
|
|
||||||
FieldCacheTTLOverridden,
|
FieldCacheTTLOverridden,
|
||||||
FieldCreatedAt,
|
FieldCreatedAt,
|
||||||
}
|
}
|
||||||
@ -245,8 +242,6 @@ var (
|
|||||||
DefaultImageCount int
|
DefaultImageCount int
|
||||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
ImageSizeValidator func(string) error
|
ImageSizeValidator func(string) error
|
||||||
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
|
||||||
MediaTypeValidator func(string) error
|
|
||||||
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
||||||
DefaultCacheTTLOverridden bool
|
DefaultCacheTTLOverridden bool
|
||||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||||
@ -436,11 +431,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
|
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ByMediaType orders the results by the media_type field.
|
|
||||||
func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
||||||
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
||||||
|
|||||||
@ -230,11 +230,6 @@ func ImageSize(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
|
|
||||||
func MediaType(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
||||||
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||||
@ -1905,81 +1900,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
|
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// MediaTypeEQ applies the EQ predicate on the "media_type" field.
|
|
||||||
func MediaTypeEQ(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
|
|
||||||
func MediaTypeNEQ(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeIn applies the In predicate on the "media_type" field.
|
|
||||||
func MediaTypeIn(vs ...string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
|
|
||||||
func MediaTypeNotIn(vs ...string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeGT applies the GT predicate on the "media_type" field.
|
|
||||||
func MediaTypeGT(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeGTE applies the GTE predicate on the "media_type" field.
|
|
||||||
func MediaTypeGTE(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeLT applies the LT predicate on the "media_type" field.
|
|
||||||
func MediaTypeLT(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeLTE applies the LTE predicate on the "media_type" field.
|
|
||||||
func MediaTypeLTE(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeContains applies the Contains predicate on the "media_type" field.
|
|
||||||
func MediaTypeContains(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
|
|
||||||
func MediaTypeHasPrefix(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
|
|
||||||
func MediaTypeHasSuffix(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
|
|
||||||
func MediaTypeIsNil() predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
|
|
||||||
func MediaTypeNotNil() predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
|
|
||||||
func MediaTypeEqualFold(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
|
|
||||||
func MediaTypeContainsFold(v string) predicate.UsageLog {
|
|
||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
||||||
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||||
|
|||||||
@ -477,20 +477,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
|
|
||||||
_c.mutation.SetMediaType(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
|
|
||||||
func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetMediaType(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
||||||
_c.mutation.SetCacheTTLOverridden(v)
|
_c.mutation.SetCacheTTLOverridden(v)
|
||||||
@ -768,11 +754,6 @@ func (_c *UsageLogCreate) check() error {
|
|||||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v, ok := _c.mutation.MediaType(); ok {
|
|
||||||
if err := usagelog.MediaTypeValidator(v); err != nil {
|
|
||||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||||
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
||||||
}
|
}
|
||||||
@ -935,10 +916,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||||
_node.ImageSize = &value
|
_node.ImageSize = &value
|
||||||
}
|
}
|
||||||
if value, ok := _c.mutation.MediaType(); ok {
|
|
||||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
|
||||||
_node.MediaType = &value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
||||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||||
_node.CacheTTLOverridden = value
|
_node.CacheTTLOverridden = value
|
||||||
@ -1702,24 +1679,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
|
|
||||||
u.Set(usagelog.FieldMediaType, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
|
|
||||||
func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
|
|
||||||
u.SetExcluded(usagelog.FieldMediaType)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearMediaType clears the value of the "media_type" field.
|
|
||||||
func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
|
|
||||||
u.SetNull(usagelog.FieldMediaType)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
||||||
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
||||||
@ -2498,27 +2457,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
|
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
|
||||||
s.SetMediaType(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
|
|
||||||
func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
|
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
|
||||||
s.UpdateMediaType()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearMediaType clears the value of the "media_type" field.
|
|
||||||
func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
|
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
|
||||||
s.ClearMediaType()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
@ -3465,27 +3403,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
|
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
|
||||||
s.SetMediaType(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
|
|
||||||
func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
|
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
|
||||||
s.UpdateMediaType()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearMediaType clears the value of the "media_type" field.
|
|
||||||
func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
|
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
|
||||||
s.ClearMediaType()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
|||||||
@ -739,26 +739,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
|
|
||||||
_u.mutation.SetMediaType(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
|
|
||||||
func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetMediaType(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearMediaType clears the value of the "media_type" field.
|
|
||||||
func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
|
|
||||||
_u.mutation.ClearMediaType()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
||||||
_u.mutation.SetCacheTTLOverridden(v)
|
_u.mutation.SetCacheTTLOverridden(v)
|
||||||
@ -912,11 +892,6 @@ func (_u *UsageLogUpdate) check() error {
|
|||||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v, ok := _u.mutation.MediaType(); ok {
|
|
||||||
if err := usagelog.MediaTypeValidator(v); err != nil {
|
|
||||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||||
}
|
}
|
||||||
@ -1124,12 +1099,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.ImageSizeCleared() {
|
if _u.mutation.ImageSizeCleared() {
|
||||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.MediaType(); ok {
|
|
||||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.MediaTypeCleared() {
|
|
||||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
@ -2005,26 +1974,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMediaType sets the "media_type" field.
|
|
||||||
func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
|
|
||||||
_u.mutation.SetMediaType(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
|
|
||||||
func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetMediaType(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearMediaType clears the value of the "media_type" field.
|
|
||||||
func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
|
|
||||||
_u.mutation.ClearMediaType()
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||||
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
||||||
_u.mutation.SetCacheTTLOverridden(v)
|
_u.mutation.SetCacheTTLOverridden(v)
|
||||||
@ -2191,11 +2140,6 @@ func (_u *UsageLogUpdateOne) check() error {
|
|||||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v, ok := _u.mutation.MediaType(); ok {
|
|
||||||
if err := usagelog.MediaTypeValidator(v); err != nil {
|
|
||||||
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||||
}
|
}
|
||||||
@ -2420,12 +2364,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if _u.mutation.ImageSizeCleared() {
|
if _u.mutation.ImageSizeCleared() {
|
||||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.MediaType(); ok {
|
|
||||||
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.MediaTypeCleared() {
|
|
||||||
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,10 +45,6 @@ type User struct {
|
|||||||
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||||
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||||
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
|
||||||
// SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
|
|
||||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
|
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the UserQuery when eager-loading is set.
|
// The values are being populated by the UserQuery when eager-loading is set.
|
||||||
Edges UserEdges `json:"edges"`
|
Edges UserEdges `json:"edges"`
|
||||||
@ -181,7 +177,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case user.FieldBalance:
|
case user.FieldBalance:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
|
case user.FieldID, user.FieldConcurrency:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@ -295,18 +291,6 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
|||||||
_m.TotpEnabledAt = new(time.Time)
|
_m.TotpEnabledAt = new(time.Time)
|
||||||
*_m.TotpEnabledAt = value.Time
|
*_m.TotpEnabledAt = value.Time
|
||||||
}
|
}
|
||||||
case user.FieldSoraStorageQuotaBytes:
|
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraStorageQuotaBytes = value.Int64
|
|
||||||
}
|
|
||||||
case user.FieldSoraStorageUsedBytes:
|
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SoraStorageUsedBytes = value.Int64
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@ -440,12 +424,6 @@ func (_m *User) String() string {
|
|||||||
builder.WriteString("totp_enabled_at=")
|
builder.WriteString("totp_enabled_at=")
|
||||||
builder.WriteString(v.Format(time.ANSIC))
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("sora_storage_quota_bytes=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
|
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("sora_storage_used_bytes=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
|
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,10 +43,6 @@ const (
|
|||||||
FieldTotpEnabled = "totp_enabled"
|
FieldTotpEnabled = "totp_enabled"
|
||||||
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||||
FieldTotpEnabledAt = "totp_enabled_at"
|
FieldTotpEnabledAt = "totp_enabled_at"
|
||||||
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
|
|
||||||
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
|
|
||||||
// FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
|
|
||||||
FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
|
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@ -156,8 +152,6 @@ var Columns = []string{
|
|||||||
FieldTotpSecretEncrypted,
|
FieldTotpSecretEncrypted,
|
||||||
FieldTotpEnabled,
|
FieldTotpEnabled,
|
||||||
FieldTotpEnabledAt,
|
FieldTotpEnabledAt,
|
||||||
FieldSoraStorageQuotaBytes,
|
|
||||||
FieldSoraStorageUsedBytes,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -214,10 +208,6 @@ var (
|
|||||||
DefaultNotes string
|
DefaultNotes string
|
||||||
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||||
DefaultTotpEnabled bool
|
DefaultTotpEnabled bool
|
||||||
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
|
|
||||||
DefaultSoraStorageQuotaBytes int64
|
|
||||||
// DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
|
|
||||||
DefaultSoraStorageUsedBytes int64
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the User queries.
|
// OrderOption defines the ordering options for the User queries.
|
||||||
@ -298,16 +288,6 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
|
|
||||||
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
|
|
||||||
func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@ -125,16 +125,6 @@ func TotpEnabledAt(v time.Time) predicate.User {
|
|||||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
|
|
||||||
func SoraStorageQuotaBytes(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
|
|
||||||
func SoraStorageUsedBytes(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.User {
|
func CreatedAtEQ(v time.Time) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@ -870,86 +860,6 @@ func TotpEnabledAtNotNil() predicate.User {
|
|||||||
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesEQ(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesGT(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesGTE(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesLT(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
|
|
||||||
func SoraStorageQuotaBytesLTE(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesEQ(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesNEQ(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesGT(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesGTE(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesLT(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
|
|
||||||
func SoraStorageUsedBytesLTE(v int64) predicate.User {
|
|
||||||
return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.User {
|
func HasAPIKeys() predicate.User {
|
||||||
return predicate.User(func(s *sql.Selector) {
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
|||||||
@ -210,34 +210,6 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
|
|
||||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
|
||||||
func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraStorageQuotaBytes(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
|
|
||||||
_c.mutation.SetSoraStorageUsedBytes(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
|
||||||
func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSoraStorageUsedBytes(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@ -452,14 +424,6 @@ func (_c *UserCreate) defaults() error {
|
|||||||
v := user.DefaultTotpEnabled
|
v := user.DefaultTotpEnabled
|
||||||
_c.mutation.SetTotpEnabled(v)
|
_c.mutation.SetTotpEnabled(v)
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
|
||||||
v := user.DefaultSoraStorageQuotaBytes
|
|
||||||
_c.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
|
|
||||||
v := user.DefaultSoraStorageUsedBytes
|
|
||||||
_c.mutation.SetSoraStorageUsedBytes(v)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,12 +487,6 @@ func (_c *UserCreate) check() error {
|
|||||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||||
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
|
||||||
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
|
|
||||||
}
|
|
||||||
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
|
|
||||||
return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -612,14 +570,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||||
_node.TotpEnabledAt = &value
|
_node.TotpEnabledAt = &value
|
||||||
}
|
}
|
||||||
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
_node.SoraStorageQuotaBytes = value
|
|
||||||
}
|
|
||||||
if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
|
|
||||||
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
|
||||||
_node.SoraStorageUsedBytes = value
|
|
||||||
}
|
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@ -1006,42 +956,6 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
|
|
||||||
u.Set(user.FieldSoraStorageQuotaBytes, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
|
|
||||||
u.SetExcluded(user.FieldSoraStorageQuotaBytes)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
|
|
||||||
u.Add(user.FieldSoraStorageQuotaBytes, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
|
|
||||||
u.Set(user.FieldSoraStorageUsedBytes, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
|
|
||||||
u.SetExcluded(user.FieldSoraStorageUsedBytes)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
|
||||||
func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
|
|
||||||
u.Add(user.FieldSoraStorageUsedBytes, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@ -1304,48 +1218,6 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.SetSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.AddSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.UpdateSoraStorageQuotaBytes()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.SetSoraStorageUsedBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
|
||||||
func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.AddSoraStorageUsedBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.UpdateSoraStorageUsedBytes()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@ -1774,48 +1646,6 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.SetSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
|
||||||
func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.AddSoraStorageQuotaBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.UpdateSoraStorageQuotaBytes()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.SetSoraStorageUsedBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
|
||||||
func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.AddSoraStorageUsedBytes(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
|
||||||
func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
|
|
||||||
return u.Update(func(s *UserUpsert) {
|
|
||||||
s.UpdateSoraStorageUsedBytes()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@ -242,48 +242,6 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
|
|
||||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
|
||||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
|
||||||
func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraStorageQuotaBytes(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
|
|
||||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
|
|
||||||
_u.mutation.ResetSoraStorageUsedBytes()
|
|
||||||
_u.mutation.SetSoraStorageUsedBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
|
||||||
func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraStorageUsedBytes(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
|
|
||||||
func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
|
|
||||||
_u.mutation.AddSoraStorageUsedBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@ -751,18 +709,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.TotpEnabledAtCleared() {
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
|
|
||||||
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
|
|
||||||
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@ -1406,48 +1352,6 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
|
|
||||||
_u.mutation.ResetSoraStorageQuotaBytes()
|
|
||||||
_u.mutation.SetSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
|
||||||
func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraStorageQuotaBytes(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
|
||||||
func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
|
|
||||||
_u.mutation.AddSoraStorageQuotaBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
|
||||||
func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
|
|
||||||
_u.mutation.ResetSoraStorageUsedBytes()
|
|
||||||
_u.mutation.SetSoraStorageUsedBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
|
||||||
func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSoraStorageUsedBytes(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
|
|
||||||
func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
|
|
||||||
_u.mutation.AddSoraStorageUsedBytes(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@ -1945,18 +1849,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
if _u.mutation.TotpEnabledAtCleared() {
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
|
||||||
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
|
|
||||||
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
|
|
||||||
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@ -77,7 +77,6 @@ type Config struct {
|
|||||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
Sora SoraConfig `mapstructure:"sora"`
|
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
@ -197,8 +196,6 @@ type TokenRefreshConfig struct {
|
|||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
// 重试退避基础时间(秒)
|
// 重试退避基础时间(秒)
|
||||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||||
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
|
||||||
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PricingConfig struct {
|
type PricingConfig struct {
|
||||||
@ -303,59 +300,6 @@ type ConcurrencyConfig struct {
|
|||||||
PingInterval int `mapstructure:"ping_interval"`
|
PingInterval int `mapstructure:"ping_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraConfig 直连 Sora 配置
|
|
||||||
type SoraConfig struct {
|
|
||||||
Client SoraClientConfig `mapstructure:"client"`
|
|
||||||
Storage SoraStorageConfig `mapstructure:"storage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraClientConfig 直连 Sora 客户端配置
|
|
||||||
type SoraClientConfig struct {
|
|
||||||
BaseURL string `mapstructure:"base_url"`
|
|
||||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
|
||||||
MaxRetries int `mapstructure:"max_retries"`
|
|
||||||
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
|
|
||||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
|
||||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
|
||||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
|
||||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
|
||||||
Debug bool `mapstructure:"debug"`
|
|
||||||
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
|
||||||
Headers map[string]string `mapstructure:"headers"`
|
|
||||||
UserAgent string `mapstructure:"user_agent"`
|
|
||||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
|
||||||
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
|
|
||||||
type SoraCurlCFFISidecarConfig struct {
|
|
||||||
Enabled bool `mapstructure:"enabled"`
|
|
||||||
BaseURL string `mapstructure:"base_url"`
|
|
||||||
Impersonate string `mapstructure:"impersonate"`
|
|
||||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
|
||||||
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
|
|
||||||
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageConfig 媒体存储配置
|
|
||||||
type SoraStorageConfig struct {
|
|
||||||
Type string `mapstructure:"type"`
|
|
||||||
LocalPath string `mapstructure:"local_path"`
|
|
||||||
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
|
||||||
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
|
||||||
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
|
|
||||||
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
|
|
||||||
Debug bool `mapstructure:"debug"`
|
|
||||||
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStorageCleanupConfig 媒体清理配置
|
|
||||||
type SoraStorageCleanupConfig struct {
|
|
||||||
Enabled bool `mapstructure:"enabled"`
|
|
||||||
Schedule string `mapstructure:"schedule"`
|
|
||||||
RetentionDays int `mapstructure:"retention_days"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GatewayConfig API网关相关配置
|
// GatewayConfig API网关相关配置
|
||||||
type GatewayConfig struct {
|
type GatewayConfig struct {
|
||||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||||
@ -424,24 +368,6 @@ type GatewayConfig struct {
|
|||||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||||
|
|
||||||
// Sora 专用配置
|
|
||||||
// SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
|
|
||||||
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
|
|
||||||
// SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
|
|
||||||
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
|
|
||||||
// SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
|
|
||||||
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
|
|
||||||
// SoraStreamMode: stream 强制策略(force/error)
|
|
||||||
SoraStreamMode string `mapstructure:"sora_stream_mode"`
|
|
||||||
// SoraModelFilters: 模型列表过滤配置
|
|
||||||
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
|
|
||||||
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
|
|
||||||
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
|
|
||||||
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
|
|
||||||
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
|
|
||||||
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
|
|
||||||
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
|
|
||||||
|
|
||||||
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||||
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||||
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||||
@ -639,12 +565,6 @@ type GatewayUsageRecordConfig struct {
|
|||||||
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
|
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraModelFiltersConfig Sora 模型过滤配置
|
|
||||||
type SoraModelFiltersConfig struct {
|
|
||||||
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
|
|
||||||
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TLSFingerprintConfig TLS指纹伪装配置
|
// TLSFingerprintConfig TLS指纹伪装配置
|
||||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||||
type TLSFingerprintConfig struct {
|
type TLSFingerprintConfig struct {
|
||||||
@ -1402,13 +1322,6 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
|
||||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
|
||||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
|
||||||
viper.SetDefault("gateway.sora_stream_mode", "force")
|
|
||||||
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
|
|
||||||
viper.SetDefault("gateway.sora_media_require_api_key", true)
|
|
||||||
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
|
|
||||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||||
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
|
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
|
||||||
@ -1465,45 +1378,12 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||||
viper.SetDefault("concurrency.ping_interval", 10)
|
viper.SetDefault("concurrency.ping_interval", 10)
|
||||||
|
|
||||||
// Sora 直连配置
|
|
||||||
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
|
||||||
viper.SetDefault("sora.client.timeout_seconds", 120)
|
|
||||||
viper.SetDefault("sora.client.max_retries", 3)
|
|
||||||
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
|
|
||||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
|
||||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
|
||||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
|
||||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
|
||||||
viper.SetDefault("sora.client.debug", false)
|
|
||||||
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
|
||||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
|
||||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
|
||||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
|
||||||
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
|
|
||||||
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
|
|
||||||
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
|
|
||||||
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
|
|
||||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
|
|
||||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
|
|
||||||
|
|
||||||
viper.SetDefault("sora.storage.type", "local")
|
|
||||||
viper.SetDefault("sora.storage.local_path", "")
|
|
||||||
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
|
||||||
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
|
||||||
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
|
|
||||||
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
|
|
||||||
viper.SetDefault("sora.storage.debug", false)
|
|
||||||
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
|
||||||
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
|
||||||
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
|
|
||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
viper.SetDefault("token_refresh.enabled", true)
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
|
||||||
|
|
||||||
// Gemini OAuth - configure via environment variables or config file
|
// Gemini OAuth - configure via environment variables or config file
|
||||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||||
@ -1879,86 +1759,6 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||||
}
|
}
|
||||||
if c.Gateway.SoraMaxBodySize < 0 {
|
|
||||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
|
|
||||||
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
|
|
||||||
switch mode {
|
|
||||||
case "force", "error":
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Sora.Client.TimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.MaxRetries < 0 {
|
|
||||||
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.PollIntervalSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.MaxPollAttempts < 0 {
|
|
||||||
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.RecentTaskLimit < 0 {
|
|
||||||
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.RecentTaskLimitMax < 0 {
|
|
||||||
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
|
|
||||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
|
||||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
|
||||||
}
|
|
||||||
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if !c.Sora.Client.CurlCFFISidecar.Enabled {
|
|
||||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
|
|
||||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
|
|
||||||
}
|
|
||||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
|
||||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Storage.MaxDownloadBytes < 0 {
|
|
||||||
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora.Storage.Cleanup.Enabled {
|
|
||||||
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
|
||||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
|
|
||||||
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
|
|
||||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
|
|
||||||
return fmt.Errorf("sora.storage.type must be 'local'")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
||||||
switch c.Gateway.ConnectionPoolIsolation {
|
switch c.Gateway.ConnectionPoolIsolation {
|
||||||
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
||||||
|
|||||||
@ -1554,94 +1554,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
|
||||||
resetViperWithJWTSecret(t)
|
|
||||||
|
|
||||||
cfg, err := Load()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Load() error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
|
||||||
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
|
|
||||||
}
|
|
||||||
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
|
|
||||||
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
|
|
||||||
}
|
|
||||||
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
|
|
||||||
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
|
|
||||||
}
|
|
||||||
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
|
|
||||||
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
|
|
||||||
}
|
|
||||||
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
|
|
||||||
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
|
|
||||||
}
|
|
||||||
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
|
|
||||||
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
|
|
||||||
resetViperWithJWTSecret(t)
|
|
||||||
|
|
||||||
cfg, err := Load()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Load() error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
|
|
||||||
err = cfg.Validate()
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
|
|
||||||
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
|
|
||||||
resetViperWithJWTSecret(t)
|
|
||||||
|
|
||||||
cfg, err := Load()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Load() error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
|
|
||||||
err = cfg.Validate()
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
|
|
||||||
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
|
|
||||||
resetViperWithJWTSecret(t)
|
|
||||||
|
|
||||||
cfg, err := Load()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Load() error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
|
|
||||||
err = cfg.Validate()
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
|
|
||||||
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
|
||||||
resetViperWithJWTSecret(t)
|
|
||||||
|
|
||||||
cfg, err := Load()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Load() error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
|
|
||||||
err = cfg.Validate()
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
|
|
||||||
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||||
resetViperWithJWTSecret(t)
|
resetViperWithJWTSecret(t)
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
|
|||||||
@ -22,7 +22,6 @@ const (
|
|||||||
PlatformOpenAI = "openai"
|
PlatformOpenAI = "openai"
|
||||||
PlatformGemini = "gemini"
|
PlatformGemini = "gemini"
|
||||||
PlatformAntigravity = "antigravity"
|
PlatformAntigravity = "antigravity"
|
||||||
PlatformSora = "sora"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
|
|||||||
@ -567,15 +567,15 @@ func defaultProxyName(name string) string {
|
|||||||
|
|
||||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
// Only applies to OpenAI OAuth accounts. Skips expired token errors silently.
|
||||||
// Existing credential values are never overwritten — only missing fields are filled.
|
// Existing credential values are never overwritten — only missing fields are filled.
|
||||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||||
if item.Credentials == nil {
|
if item.Credentials == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Only enrich OpenAI/Sora OAuth accounts
|
// Only enrich OpenAI OAuth accounts
|
||||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
if platform != service.PlatformOpenAI {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||||
|
|||||||
@ -1875,12 +1875,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Sora accounts
|
|
||||||
if account.Platform == service.PlatformSora {
|
|
||||||
response.Success(c, service.DefaultSoraModels(nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle Claude/Anthropic accounts
|
// Handle Claude/Anthropic accounts
|
||||||
// For OAuth and Setup-Token accounts: return default models
|
// For OAuth and Setup-Token accounts: return default models
|
||||||
if account.IsOAuth() {
|
if account.IsOAuth() {
|
||||||
|
|||||||
@ -380,7 +380,6 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
|
|||||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
|
|||||||
type CreateGroupRequest struct {
|
type CreateGroupRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
IsExclusive bool `json:"is_exclusive"`
|
IsExclusive bool `json:"is_exclusive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
@ -95,10 +95,6 @@ type CreateGroupRequest struct {
|
|||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
@ -108,8 +104,6 @@ type CreateGroupRequest struct {
|
|||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||||
@ -123,7 +117,7 @@ type CreateGroupRequest struct {
|
|||||||
type UpdateGroupRequest struct {
|
type UpdateGroupRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
IsExclusive *bool `json:"is_exclusive"`
|
IsExclusive *bool `json:"is_exclusive"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
@ -135,10 +129,6 @@ type UpdateGroupRequest struct {
|
|||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
|
||||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
@ -148,8 +138,6 @@ type UpdateGroupRequest struct {
|
|||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
|
||||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
|
||||||
RequireOAuthOnly *bool `json:"require_oauth_only"`
|
RequireOAuthOnly *bool `json:"require_oauth_only"`
|
||||||
@ -258,10 +246,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
SoraImagePrice360: req.SoraImagePrice360,
|
|
||||||
SoraImagePrice540: req.SoraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
|
||||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||||
@ -269,7 +253,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
RequireOAuthOnly: req.RequireOAuthOnly,
|
RequireOAuthOnly: req.RequireOAuthOnly,
|
||||||
RequirePrivacySet: req.RequirePrivacySet,
|
RequirePrivacySet: req.RequirePrivacySet,
|
||||||
@ -313,10 +296,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
SoraImagePrice360: req.SoraImagePrice360,
|
|
||||||
SoraImagePrice540: req.SoraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
|
||||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||||
@ -324,7 +303,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
RequireOAuthOnly: req.RequireOAuthOnly,
|
RequireOAuthOnly: req.RequireOAuthOnly,
|
||||||
RequirePrivacySet: req.RequirePrivacySet,
|
RequirePrivacySet: req.RequirePrivacySet,
|
||||||
|
|||||||
@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func oauthPlatformFromPath(c *gin.Context) string {
|
func oauthPlatformFromPath(c *gin.Context) string {
|
||||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
|
||||||
return service.PlatformSora
|
|
||||||
}
|
|
||||||
return service.PlatformOpenAI
|
return service.PlatformOpenAI
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct {
|
|||||||
|
|
||||||
// RefreshToken refreshes an OpenAI OAuth token
|
// RefreshToken refreshes an OpenAI OAuth token
|
||||||
// POST /api/v1/admin/openai/refresh-token
|
// POST /api/v1/admin/openai/refresh-token
|
||||||
// POST /api/v1/admin/sora/rt2at
|
|
||||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||||
var req OpenAIRefreshTokenRequest
|
var req OpenAIRefreshTokenRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
response.Success(c, tokenInfo)
|
response.Success(c, tokenInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||||
// POST /api/v1/admin/sora/st2at
|
|
||||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
|
||||||
var req struct {
|
|
||||||
SessionToken string `json:"session_token"`
|
|
||||||
ST string `json:"st"`
|
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
|
||||||
if sessionToken == "" {
|
|
||||||
sessionToken = strings.TrimSpace(req.ST)
|
|
||||||
}
|
|
||||||
if sessionToken == "" {
|
|
||||||
response.BadRequest(c, "session_token is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, tokenInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
|
||||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
|
||||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
|||||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||||
// POST /api/v1/admin/openai/create-from-oauth
|
// POST /api/v1/admin/openai/create-from-oauth
|
||||||
// POST /api/v1/admin/sora/create-from-oauth
|
|
||||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
name = tokenInfo.Email
|
name = tokenInfo.Email
|
||||||
}
|
}
|
||||||
if name == "" {
|
if name == "" {
|
||||||
if platform == service.PlatformSora {
|
name = "OpenAI OAuth Account"
|
||||||
name = "Sora OAuth Account"
|
|
||||||
} else {
|
|
||||||
name = "OpenAI OAuth Account"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create account
|
// Create account
|
||||||
|
|||||||
@ -41,17 +41,15 @@ type SettingHandler struct {
|
|||||||
emailService *service.EmailService
|
emailService *service.EmailService
|
||||||
turnstileService *service.TurnstileService
|
turnstileService *service.TurnstileService
|
||||||
opsService *service.OpsService
|
opsService *service.OpsService
|
||||||
soraS3Storage *service.SoraS3Storage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSettingHandler 创建系统设置处理器
|
// NewSettingHandler 创建系统设置处理器
|
||||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
|
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||||
return &SettingHandler{
|
return &SettingHandler{
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
emailService: emailService,
|
emailService: emailService,
|
||||||
turnstileService: turnstileService,
|
turnstileService: turnstileService,
|
||||||
opsService: opsService,
|
opsService: opsService,
|
||||||
soraS3Storage: soraS3Storage,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,7 +106,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
HideCcsImportButton: settings.HideCcsImportButton,
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
|
||||||
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
|
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
|
||||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||||
DefaultConcurrency: settings.DefaultConcurrency,
|
DefaultConcurrency: settings.DefaultConcurrency,
|
||||||
@ -177,7 +174,6 @@ type UpdateSettingsRequest struct {
|
|||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
|
||||||
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
|
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
|
||||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||||
|
|
||||||
@ -566,7 +562,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
HideCcsImportButton: req.HideCcsImportButton,
|
HideCcsImportButton: req.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||||
PurchaseSubscriptionURL: purchaseURL,
|
PurchaseSubscriptionURL: purchaseURL,
|
||||||
SoraClientEnabled: req.SoraClientEnabled,
|
|
||||||
CustomMenuItems: customMenuJSON,
|
CustomMenuItems: customMenuJSON,
|
||||||
CustomEndpoints: customEndpointsJSON,
|
CustomEndpoints: customEndpointsJSON,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
@ -676,7 +671,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
|
||||||
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
|
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
|
||||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||||
@ -1207,384 +1201,6 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
|
|
||||||
if settings == nil {
|
|
||||||
return dto.SoraS3Settings{}
|
|
||||||
}
|
|
||||||
return dto.SoraS3Settings{
|
|
||||||
Enabled: settings.Enabled,
|
|
||||||
Endpoint: settings.Endpoint,
|
|
||||||
Region: settings.Region,
|
|
||||||
Bucket: settings.Bucket,
|
|
||||||
AccessKeyID: settings.AccessKeyID,
|
|
||||||
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
|
|
||||||
Prefix: settings.Prefix,
|
|
||||||
ForcePathStyle: settings.ForcePathStyle,
|
|
||||||
CDNURL: settings.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
|
|
||||||
return dto.SoraS3Profile{
|
|
||||||
ProfileID: profile.ProfileID,
|
|
||||||
Name: profile.Name,
|
|
||||||
IsActive: profile.IsActive,
|
|
||||||
Enabled: profile.Enabled,
|
|
||||||
Endpoint: profile.Endpoint,
|
|
||||||
Region: profile.Region,
|
|
||||||
Bucket: profile.Bucket,
|
|
||||||
AccessKeyID: profile.AccessKeyID,
|
|
||||||
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
|
|
||||||
Prefix: profile.Prefix,
|
|
||||||
ForcePathStyle: profile.ForcePathStyle,
|
|
||||||
CDNURL: profile.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
|
|
||||||
UpdatedAt: profile.UpdatedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
|
|
||||||
if !enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(endpoint) == "" {
|
|
||||||
return fmt.Errorf("S3 Endpoint is required when enabled")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(bucket) == "" {
|
|
||||||
return fmt.Errorf("S3 Bucket is required when enabled")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(accessKeyID) == "" {
|
|
||||||
return fmt.Errorf("S3 Access Key ID is required when enabled")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("S3 Secret Access Key is required when enabled")
|
|
||||||
}
|
|
||||||
|
|
||||||
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
|
|
||||||
for idx := range items {
|
|
||||||
if items[idx].ProfileID == profileID {
|
|
||||||
return &items[idx]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
|
|
||||||
// GET /api/v1/admin/settings/sora-s3
|
|
||||||
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
|
|
||||||
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, toSoraS3SettingsDTO(settings))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListSoraS3Profiles 获取 Sora S3 多配置
|
|
||||||
// GET /api/v1/admin/settings/sora-s3/profiles
|
|
||||||
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
|
|
||||||
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
items := make([]dto.SoraS3Profile, 0, len(result.Items))
|
|
||||||
for idx := range result.Items {
|
|
||||||
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
|
|
||||||
}
|
|
||||||
response.Success(c, dto.ListSoraS3ProfilesResponse{
|
|
||||||
ActiveProfileID: result.ActiveProfileID,
|
|
||||||
Items: items,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
|
|
||||||
type UpdateSoraS3SettingsRequest struct {
|
|
||||||
ProfileID string `json:"profile_id"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKey string `json:"secret_access_key"`
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CreateSoraS3ProfileRequest struct {
|
|
||||||
ProfileID string `json:"profile_id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
SetActive bool `json:"set_active"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKey string `json:"secret_access_key"`
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UpdateSoraS3ProfileRequest struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKey string `json:"secret_access_key"`
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
|
||||||
// POST /api/v1/admin/settings/sora-s3/profiles
|
|
||||||
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
|
|
||||||
var req CreateSoraS3ProfileRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.DefaultStorageQuotaBytes < 0 {
|
|
||||||
req.DefaultStorageQuotaBytes = 0
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Name) == "" {
|
|
||||||
response.BadRequest(c, "Name is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.ProfileID) == "" {
|
|
||||||
response.BadRequest(c, "Profile ID is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
|
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
|
|
||||||
ProfileID: req.ProfileID,
|
|
||||||
Name: req.Name,
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
Endpoint: req.Endpoint,
|
|
||||||
Region: req.Region,
|
|
||||||
Bucket: req.Bucket,
|
|
||||||
AccessKeyID: req.AccessKeyID,
|
|
||||||
SecretAccessKey: req.SecretAccessKey,
|
|
||||||
Prefix: req.Prefix,
|
|
||||||
ForcePathStyle: req.ForcePathStyle,
|
|
||||||
CDNURL: req.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
|
||||||
}, req.SetActive)
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, toSoraS3ProfileDTO(*created))
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
|
||||||
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
|
||||||
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
|
|
||||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
|
||||||
if profileID == "" {
|
|
||||||
response.BadRequest(c, "Profile ID is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req UpdateSoraS3ProfileRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.DefaultStorageQuotaBytes < 0 {
|
|
||||||
req.DefaultStorageQuotaBytes = 0
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Name) == "" {
|
|
||||||
response.BadRequest(c, "Name is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
existing := findSoraS3ProfileByID(existingList.Items, profileID)
|
|
||||||
if existing == nil {
|
|
||||||
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
|
|
||||||
Name: req.Name,
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
Endpoint: req.Endpoint,
|
|
||||||
Region: req.Region,
|
|
||||||
Bucket: req.Bucket,
|
|
||||||
AccessKeyID: req.AccessKeyID,
|
|
||||||
SecretAccessKey: req.SecretAccessKey,
|
|
||||||
Prefix: req.Prefix,
|
|
||||||
ForcePathStyle: req.ForcePathStyle,
|
|
||||||
CDNURL: req.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
|
||||||
})
|
|
||||||
if updateErr != nil {
|
|
||||||
response.ErrorFrom(c, updateErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, toSoraS3ProfileDTO(*updated))
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
|
||||||
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
|
||||||
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
|
|
||||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
|
||||||
if profileID == "" {
|
|
||||||
response.BadRequest(c, "Profile ID is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, gin.H{"deleted": true})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
|
|
||||||
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
|
|
||||||
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
|
|
||||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
|
||||||
if profileID == "" {
|
|
||||||
response.BadRequest(c, "Profile ID is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, toSoraS3ProfileDTO(*active))
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
|
|
||||||
// PUT /api/v1/admin/settings/sora-s3
|
|
||||||
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
|
|
||||||
var req UpdateSoraS3SettingsRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.DefaultStorageQuotaBytes < 0 {
|
|
||||||
req.DefaultStorageQuotaBytes = 0
|
|
||||||
}
|
|
||||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
settings := &service.SoraS3Settings{
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
Endpoint: req.Endpoint,
|
|
||||||
Region: req.Region,
|
|
||||||
Bucket: req.Bucket,
|
|
||||||
AccessKeyID: req.AccessKeyID,
|
|
||||||
SecretAccessKey: req.SecretAccessKey,
|
|
||||||
Prefix: req.Prefix,
|
|
||||||
ForcePathStyle: req.ForcePathStyle,
|
|
||||||
CDNURL: req.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
|
||||||
}
|
|
||||||
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
|
|
||||||
// POST /api/v1/admin/settings/sora-s3/test
|
|
||||||
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
|
||||||
if h.soraS3Storage == nil {
|
|
||||||
response.Error(c, 500, "S3 存储服务未初始化")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req UpdateSoraS3SettingsRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !req.Enabled {
|
|
||||||
response.BadRequest(c, "S3 未启用,无法测试连接")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.SecretAccessKey == "" {
|
|
||||||
if req.ProfileID != "" {
|
|
||||||
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
|
||||||
if err == nil {
|
|
||||||
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
|
|
||||||
if profile != nil {
|
|
||||||
req.SecretAccessKey = profile.SecretAccessKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if req.SecretAccessKey == "" {
|
|
||||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
|
||||||
if err == nil {
|
|
||||||
req.SecretAccessKey = existing.SecretAccessKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
testCfg := &service.SoraS3Settings{
|
|
||||||
Enabled: true,
|
|
||||||
Endpoint: req.Endpoint,
|
|
||||||
Region: req.Region,
|
|
||||||
Bucket: req.Bucket,
|
|
||||||
AccessKeyID: req.AccessKeyID,
|
|
||||||
SecretAccessKey: req.SecretAccessKey,
|
|
||||||
Prefix: req.Prefix,
|
|
||||||
ForcePathStyle: req.ForcePathStyle,
|
|
||||||
CDNURL: req.CDNURL,
|
|
||||||
}
|
|
||||||
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
|
|
||||||
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRectifierSettings 获取请求整流器配置
|
// GetRectifierSettings 获取请求整流器配置
|
||||||
// GET /api/v1/admin/settings/rectifier
|
// GET /api/v1/admin/settings/rectifier
|
||||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||||
|
|||||||
@ -34,14 +34,13 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
|
|||||||
|
|
||||||
// CreateUserRequest represents admin create user request
|
// CreateUserRequest represents admin create user request
|
||||||
type CreateUserRequest struct {
|
type CreateUserRequest struct {
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" binding:"required,email"`
|
||||||
Password string `json:"password" binding:"required,min=6"`
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
AllowedGroups []int64 `json:"allowed_groups"`
|
AllowedGroups []int64 `json:"allowed_groups"`
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUserRequest represents admin update user request
|
// UpdateUserRequest represents admin update user request
|
||||||
@ -57,8 +56,7 @@ type UpdateUserRequest struct {
|
|||||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateBalanceRequest represents balance update request
|
// UpdateBalanceRequest represents balance update request
|
||||||
@ -182,14 +180,13 @@ func (h *UserHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@ -216,16 +213,15 @@ func (h *UserHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
GroupRates: req.GroupRates,
|
GroupRates: req.GroupRates,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@ -59,11 +59,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &AdminUser{
|
return &AdminUser{
|
||||||
User: *base,
|
User: *base,
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
GroupRates: u.GroupRates,
|
GroupRates: u.GroupRates,
|
||||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
|
||||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,14 +170,9 @@ func groupFromServiceBase(g *service.Group) Group {
|
|||||||
ImagePrice1K: g.ImagePrice1K,
|
ImagePrice1K: g.ImagePrice1K,
|
||||||
ImagePrice2K: g.ImagePrice2K,
|
ImagePrice2K: g.ImagePrice2K,
|
||||||
ImagePrice4K: g.ImagePrice4K,
|
ImagePrice4K: g.ImagePrice4K,
|
||||||
SoraImagePrice360: g.SoraImagePrice360,
|
|
||||||
SoraImagePrice540: g.SoraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
|
||||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
|
||||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||||
RequireOAuthOnly: g.RequireOAuthOnly,
|
RequireOAuthOnly: g.RequireOAuthOnly,
|
||||||
RequirePrivacySet: g.RequirePrivacySet,
|
RequirePrivacySet: g.RequirePrivacySet,
|
||||||
|
|||||||
@ -61,7 +61,6 @@ type SystemSettings struct {
|
|||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
|
||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||||
|
|
||||||
@ -128,49 +127,10 @@ type PublicSettings struct {
|
|||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
|
||||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
|
||||||
type SoraS3Settings struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
|
||||||
type SoraS3Profile struct {
|
|
||||||
ProfileID string `json:"profile_id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
IsActive bool `json:"is_active"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
UpdatedAt string `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
|
||||||
type ListSoraS3ProfilesResponse struct {
|
|
||||||
ActiveProfileID string `json:"active_profile_id"`
|
|
||||||
Items []SoraS3Profile `json:"items"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||||
type OverloadCooldownSettings struct {
|
type OverloadCooldownSettings struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@ -26,9 +26,7 @@ type AdminUser struct {
|
|||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
// map[groupID]rateMultiplier
|
// map[groupID]rateMultiplier
|
||||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
|
||||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
@ -84,21 +82,12 @@ type Group struct {
|
|||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
|
|
||||||
// Sora 按次计费配置
|
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
|
||||||
|
|
||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
// 无效请求兜底分组
|
// 无效请求兜底分组
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
|
|
||||||
// Sora 存储配额
|
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
|
||||||
|
|
||||||
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
|
||||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,7 @@ const (
|
|||||||
// ──────────────────────────────────────────────────────────
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
// NormalizeInboundEndpoint maps a raw request path (which may carry
|
||||||
// prefixes like /antigravity, /openai, /sora) to its canonical form.
|
// prefixes like /antigravity, /openai) to its canonical form.
|
||||||
//
|
//
|
||||||
// "/antigravity/v1/messages" → "/v1/messages"
|
// "/antigravity/v1/messages" → "/v1/messages"
|
||||||
// "/v1/chat/completions" → "/v1/chat/completions"
|
// "/v1/chat/completions" → "/v1/chat/completions"
|
||||||
@ -61,7 +61,7 @@ func NormalizeInboundEndpoint(path string) string {
|
|||||||
// such as /v1/responses/compact preserved from the raw URL).
|
// such as /v1/responses/compact preserved from the raw URL).
|
||||||
// - Anthropic → /v1/messages
|
// - Anthropic → /v1/messages
|
||||||
// - Gemini → /v1beta/models
|
// - Gemini → /v1beta/models
|
||||||
// - Sora → /v1/chat/completions
|
// - Antigravity → /v1/messages (Claude) or gemini (Gemini)
|
||||||
// - Antigravity routes may target either Claude or Gemini, so the
|
// - Antigravity routes may target either Claude or Gemini, so the
|
||||||
// inbound endpoint is used to distinguish.
|
// inbound endpoint is used to distinguish.
|
||||||
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||||
@ -82,9 +82,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
|||||||
case service.PlatformGemini:
|
case service.PlatformGemini:
|
||||||
return EndpointGeminiModels
|
return EndpointGeminiModels
|
||||||
|
|
||||||
case service.PlatformSora:
|
|
||||||
return EndpointChatCompletions
|
|
||||||
|
|
||||||
case service.PlatformAntigravity:
|
case service.PlatformAntigravity:
|
||||||
// Antigravity accounts serve both Claude and Gemini.
|
// Antigravity accounts serve both Claude and Gemini.
|
||||||
if inbound == EndpointGeminiModels {
|
if inbound == EndpointGeminiModels {
|
||||||
|
|||||||
@ -27,11 +27,10 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
|
|||||||
{"/v1/responses", EndpointResponses},
|
{"/v1/responses", EndpointResponses},
|
||||||
{"/v1beta/models", EndpointGeminiModels},
|
{"/v1beta/models", EndpointGeminiModels},
|
||||||
|
|
||||||
// Prefixed paths (antigravity, openai, sora).
|
// Prefixed paths (antigravity, openai).
|
||||||
{"/antigravity/v1/messages", EndpointMessages},
|
{"/antigravity/v1/messages", EndpointMessages},
|
||||||
{"/openai/v1/responses", EndpointResponses},
|
{"/openai/v1/responses", EndpointResponses},
|
||||||
{"/openai/v1/responses/compact", EndpointResponses},
|
{"/openai/v1/responses/compact", EndpointResponses},
|
||||||
{"/sora/v1/chat/completions", EndpointChatCompletions},
|
|
||||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||||
|
|
||||||
// Gin route patterns with wildcards.
|
// Gin route patterns with wildcards.
|
||||||
@ -68,9 +67,6 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
|
|||||||
// Gemini.
|
// Gemini.
|
||||||
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
|
||||||
|
|
||||||
// Sora.
|
|
||||||
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
|
|
||||||
|
|
||||||
// OpenAI — always /v1/responses.
|
// OpenAI — always /v1/responses.
|
||||||
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
|
||||||
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
|
||||||
|
|||||||
@ -859,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
|||||||
platform = forcedPlatform
|
platform = forcedPlatform
|
||||||
}
|
}
|
||||||
|
|
||||||
if platform == service.PlatformSora {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"object": "list",
|
|
||||||
"data": service.DefaultSoraModels(h.cfg),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get available models from account configurations (without platform filter)
|
// Get available models from account configurations (without platform filter)
|
||||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||||
|
|
||||||
|
|||||||
@ -45,8 +45,6 @@ type Handlers struct {
|
|||||||
Admin *AdminHandlers
|
Admin *AdminHandlers
|
||||||
Gateway *GatewayHandler
|
Gateway *GatewayHandler
|
||||||
OpenAIGateway *OpenAIGatewayHandler
|
OpenAIGateway *OpenAIGatewayHandler
|
||||||
SoraGateway *SoraGatewayHandler
|
|
||||||
SoraClient *SoraClientHandler
|
|
||||||
Setting *SettingHandler
|
Setting *SettingHandler
|
||||||
Totp *TotpHandler
|
Totp *TotpHandler
|
||||||
}
|
}
|
||||||
|
|||||||
@ -54,7 +54,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
|
||||||
BackendModeEnabled: settings.BackendModeEnabled,
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,979 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// 上游模型缓存 TTL
|
|
||||||
modelCacheTTL = 1 * time.Hour // 上游获取成功
|
|
||||||
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
|
|
||||||
)
|
|
||||||
|
|
||||||
// SoraClientHandler 处理 Sora 客户端 API 请求。
|
|
||||||
type SoraClientHandler struct {
|
|
||||||
genService *service.SoraGenerationService
|
|
||||||
quotaService *service.SoraQuotaService
|
|
||||||
s3Storage *service.SoraS3Storage
|
|
||||||
soraGatewayService *service.SoraGatewayService
|
|
||||||
gatewayService *service.GatewayService
|
|
||||||
mediaStorage *service.SoraMediaStorage
|
|
||||||
apiKeyService *service.APIKeyService
|
|
||||||
|
|
||||||
// 上游模型缓存
|
|
||||||
modelCacheMu sync.RWMutex
|
|
||||||
cachedFamilies []service.SoraModelFamily
|
|
||||||
modelCacheTime time.Time
|
|
||||||
modelCacheUpstream bool // 是否来自上游(决定 TTL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSoraClientHandler 创建 Sora 客户端 Handler。
|
|
||||||
func NewSoraClientHandler(
|
|
||||||
genService *service.SoraGenerationService,
|
|
||||||
quotaService *service.SoraQuotaService,
|
|
||||||
s3Storage *service.SoraS3Storage,
|
|
||||||
soraGatewayService *service.SoraGatewayService,
|
|
||||||
gatewayService *service.GatewayService,
|
|
||||||
mediaStorage *service.SoraMediaStorage,
|
|
||||||
apiKeyService *service.APIKeyService,
|
|
||||||
) *SoraClientHandler {
|
|
||||||
return &SoraClientHandler{
|
|
||||||
genService: genService,
|
|
||||||
quotaService: quotaService,
|
|
||||||
s3Storage: s3Storage,
|
|
||||||
soraGatewayService: soraGatewayService,
|
|
||||||
gatewayService: gatewayService,
|
|
||||||
mediaStorage: mediaStorage,
|
|
||||||
apiKeyService: apiKeyService,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateRequest 生成请求。
|
|
||||||
type GenerateRequest struct {
|
|
||||||
Model string `json:"model" binding:"required"`
|
|
||||||
Prompt string `json:"prompt" binding:"required"`
|
|
||||||
MediaType string `json:"media_type"` // video / image,默认 video
|
|
||||||
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
|
|
||||||
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
|
|
||||||
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate 异步生成 — 创建 pending 记录后立即返回。
|
|
||||||
// POST /api/v1/sora/generate
|
|
||||||
func (h *SoraClientHandler) Generate(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req GenerateRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.MediaType == "" {
|
|
||||||
req.MediaType = "video"
|
|
||||||
}
|
|
||||||
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
|
|
||||||
|
|
||||||
// 并发数检查(最多 3 个)
|
|
||||||
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if activeCount >= 3 {
|
|
||||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 配额检查(粗略检查,实际文件大小在上传后才知道)
|
|
||||||
if h.quotaService != nil {
|
|
||||||
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
|
|
||||||
var quotaErr *service.QuotaExceededError
|
|
||||||
if errors.As(err, "aErr) {
|
|
||||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Error(c, http.StatusForbidden, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取 API Key ID 和 Group ID
|
|
||||||
var apiKeyID *int64
|
|
||||||
var groupID *int64
|
|
||||||
|
|
||||||
if req.APIKeyID != nil && h.apiKeyService != nil {
|
|
||||||
// 前端传递了 api_key_id,需要校验
|
|
||||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusBadRequest, "API Key 不存在")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if apiKey.UserID != userID {
|
|
||||||
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if apiKey.Status != service.StatusAPIKeyActive {
|
|
||||||
response.Error(c, http.StatusForbidden, "API Key 不可用")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
apiKeyID = &apiKey.ID
|
|
||||||
groupID = apiKey.GroupID
|
|
||||||
} else if id, ok := c.Get("api_key_id"); ok {
|
|
||||||
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
|
|
||||||
if v, ok := id.(int64); ok {
|
|
||||||
apiKeyID = &v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
|
|
||||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 启动后台异步生成 goroutine
|
|
||||||
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
|
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
|
||||||
"generation_id": gen.ID,
|
|
||||||
"status": gen.Status,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// processGeneration 后台异步执行 Sora 生成任务。
|
|
||||||
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
|
|
||||||
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// 标记为生成中
|
|
||||||
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
|
|
||||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"handler.sora_client",
|
|
||||||
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
|
|
||||||
genID,
|
|
||||||
userID,
|
|
||||||
groupIDForLog(groupID),
|
|
||||||
model,
|
|
||||||
mediaType,
|
|
||||||
videoCount,
|
|
||||||
strings.TrimSpace(imageInput) != "",
|
|
||||||
len(strings.TrimSpace(prompt)),
|
|
||||||
)
|
|
||||||
|
|
||||||
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
|
|
||||||
if groupID == nil {
|
|
||||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.gatewayService == nil {
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 选择 Sora 账号
|
|
||||||
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"handler.sora_client",
|
|
||||||
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
|
|
||||||
genID,
|
|
||||||
userID,
|
|
||||||
groupIDForLog(groupID),
|
|
||||||
model,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"handler.sora_client",
|
|
||||||
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
|
|
||||||
genID,
|
|
||||||
userID,
|
|
||||||
groupIDForLog(groupID),
|
|
||||||
model,
|
|
||||||
account.ID,
|
|
||||||
account.Name,
|
|
||||||
account.Platform,
|
|
||||||
account.Type,
|
|
||||||
)
|
|
||||||
|
|
||||||
// 构建 chat completions 请求体(非流式)
|
|
||||||
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
|
|
||||||
|
|
||||||
if h.soraGatewayService == nil {
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
mockGinCtx, _ := gin.CreateTestContext(recorder)
|
|
||||||
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
|
|
||||||
|
|
||||||
// 调用 Forward(非流式)
|
|
||||||
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"handler.sora_client",
|
|
||||||
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
|
|
||||||
genID,
|
|
||||||
account.ID,
|
|
||||||
model,
|
|
||||||
recorder.Code,
|
|
||||||
trimForLog(recorder.Body.String(), 400),
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
// 检查是否已取消
|
|
||||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
|
||||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
|
|
||||||
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
|
|
||||||
if mediaURL == "" {
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"handler.sora_client",
|
|
||||||
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
|
|
||||||
genID,
|
|
||||||
account.ID,
|
|
||||||
model,
|
|
||||||
recorder.Code,
|
|
||||||
trimForLog(recorder.Body.String(), 400),
|
|
||||||
)
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查任务是否已被取消
|
|
||||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
|
||||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 三层降级存储:S3 → 本地 → 上游临时 URL
|
|
||||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
|
|
||||||
|
|
||||||
usageAdded := false
|
|
||||||
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
|
|
||||||
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
|
|
||||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
|
||||||
var quotaErr *service.QuotaExceededError
|
|
||||||
if errors.As(err, "aErr) {
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
usageAdded = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
|
|
||||||
gen, _ = h.genService.GetByID(ctx, genID, userID)
|
|
||||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
|
|
||||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
|
||||||
if usageAdded && h.quotaService != nil {
|
|
||||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 标记完成
|
|
||||||
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
|
|
||||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
|
||||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
|
||||||
if usageAdded && h.quotaService != nil {
|
|
||||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
|
|
||||||
func (h *SoraClientHandler) storeMediaWithDegradation(
|
|
||||||
ctx context.Context, userID int64, mediaType string,
|
|
||||||
mediaURL string, mediaURLs []string,
|
|
||||||
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
|
|
||||||
urls := mediaURLs
|
|
||||||
if len(urls) == 0 {
|
|
||||||
urls = []string{mediaURL}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 第一层:尝试 S3
|
|
||||||
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
|
|
||||||
keys := make([]string, 0, len(urls))
|
|
||||||
var totalSize int64
|
|
||||||
allOK := true
|
|
||||||
for _, u := range urls {
|
|
||||||
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
|
|
||||||
allOK = false
|
|
||||||
// 清理已上传的文件
|
|
||||||
if len(keys) > 0 {
|
|
||||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
keys = append(keys, key)
|
|
||||||
totalSize += size
|
|
||||||
}
|
|
||||||
if allOK && len(keys) > 0 {
|
|
||||||
accessURLs := make([]string, 0, len(keys))
|
|
||||||
for _, key := range keys {
|
|
||||||
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
|
|
||||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
|
||||||
allOK = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
accessURLs = append(accessURLs, accessURL)
|
|
||||||
}
|
|
||||||
if allOK && len(accessURLs) > 0 {
|
|
||||||
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 第二层:尝试本地存储
|
|
||||||
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
|
|
||||||
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
|
|
||||||
if err == nil && len(storedPaths) > 0 {
|
|
||||||
firstPath := storedPaths[0]
|
|
||||||
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
|
|
||||||
if sizeErr != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
|
|
||||||
}
|
|
||||||
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 第三层:保留上游临时 URL
|
|
||||||
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
|
|
||||||
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
|
|
||||||
body := map[string]any{
|
|
||||||
"model": model,
|
|
||||||
"messages": []map[string]string{
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
},
|
|
||||||
"stream": false,
|
|
||||||
}
|
|
||||||
if imageInput != "" {
|
|
||||||
body["image_input"] = imageInput
|
|
||||||
}
|
|
||||||
if videoCount > 1 {
|
|
||||||
body["video_count"] = videoCount
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(body)
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeVideoCount(mediaType string, videoCount int) int {
|
|
||||||
if mediaType != "video" {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
if videoCount <= 0 {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
if videoCount > 3 {
|
|
||||||
return 3
|
|
||||||
}
|
|
||||||
return videoCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
|
|
||||||
// OAuth 路径:ForwardResult.MediaURL 已填充。
|
|
||||||
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
|
|
||||||
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
|
|
||||||
// 优先从 ForwardResult 获取(OAuth 路径)
|
|
||||||
if result != nil && result.MediaURL != "" {
|
|
||||||
// 尝试从响应体获取完整 URL 列表
|
|
||||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
|
||||||
return urls[0], urls
|
|
||||||
}
|
|
||||||
return result.MediaURL, []string{result.MediaURL}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从响应体解析(APIKey 路径)
|
|
||||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
|
||||||
return urls[0], urls
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
|
|
||||||
func parseMediaURLsFromBody(body []byte) []string {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var resp map[string]any
|
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优先 media_urls(多图数组)
|
|
||||||
if rawURLs, ok := resp["media_urls"]; ok {
|
|
||||||
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
|
|
||||||
urls := make([]string, 0, len(arr))
|
|
||||||
for _, item := range arr {
|
|
||||||
if s, ok := item.(string); ok && s != "" {
|
|
||||||
urls = append(urls, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(urls) > 0 {
|
|
||||||
return urls
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 回退到 media_url(单个 URL)
|
|
||||||
if url, ok := resp["media_url"].(string); ok && url != "" {
|
|
||||||
return []string{url}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListGenerations 查询生成记录列表。
|
|
||||||
// GET /api/v1/sora/generations
|
|
||||||
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
|
||||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
|
||||||
|
|
||||||
params := service.SoraGenerationListParams{
|
|
||||||
UserID: userID,
|
|
||||||
Status: c.Query("status"),
|
|
||||||
StorageType: c.Query("storage_type"),
|
|
||||||
MediaType: c.Query("media_type"),
|
|
||||||
Page: page,
|
|
||||||
PageSize: pageSize,
|
|
||||||
}
|
|
||||||
|
|
||||||
gens, total, err := h.genService.List(c.Request.Context(), params)
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 为 S3 记录动态生成预签名 URL
|
|
||||||
for _, gen := range gens {
|
|
||||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
|
||||||
"data": gens,
|
|
||||||
"total": total,
|
|
||||||
"page": page,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGeneration 查询生成记录详情。
|
|
||||||
// GET /api/v1/sora/generations/:id
|
|
||||||
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusNotFound, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
|
||||||
response.Success(c, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteGeneration 删除生成记录。
|
|
||||||
// DELETE /api/v1/sora/generations/:id
|
|
||||||
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusNotFound, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
|
|
||||||
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
|
|
||||||
paths := gen.MediaURLs
|
|
||||||
if len(paths) == 0 && gen.MediaURL != "" {
|
|
||||||
paths = []string{gen.MediaURL}
|
|
||||||
}
|
|
||||||
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
|
|
||||||
response.Error(c, http.StatusNotFound, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, gin.H{"message": "已删除"})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetQuota 查询用户存储配额。
|
|
||||||
// GET /api/v1/sora/quota
|
|
||||||
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.quotaService == nil {
|
|
||||||
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
|
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, quota)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CancelGeneration 取消生成任务。
|
|
||||||
// POST /api/v1/sora/generations/:id/cancel
|
|
||||||
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 权限校验
|
|
||||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusNotFound, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = gen
|
|
||||||
|
|
||||||
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
|
|
||||||
if errors.Is(err, service.ErrSoraGenerationNotActive) {
|
|
||||||
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Error(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, gin.H{"message": "已取消"})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveToStorage 手动保存 upstream 记录到 S3。
|
|
||||||
// POST /api/v1/sora/generations/:id/save
|
|
||||||
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
|
|
||||||
userID := getUserIDFromContext(c)
|
|
||||||
if userID == 0 {
|
|
||||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusNotFound, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if gen.StorageType != service.SoraStorageTypeUpstream {
|
|
||||||
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gen.MediaURL == "" {
|
|
||||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
|
|
||||||
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceURLs := gen.MediaURLs
|
|
||||||
if len(sourceURLs) == 0 && gen.MediaURL != "" {
|
|
||||||
sourceURLs = []string{gen.MediaURL}
|
|
||||||
}
|
|
||||||
if len(sourceURLs) == 0 {
|
|
||||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
uploadedKeys := make([]string, 0, len(sourceURLs))
|
|
||||||
accessURLs := make([]string, 0, len(sourceURLs))
|
|
||||||
var totalSize int64
|
|
||||||
|
|
||||||
for _, sourceURL := range sourceURLs {
|
|
||||||
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
|
|
||||||
if uploadErr != nil {
|
|
||||||
if len(uploadedKeys) > 0 {
|
|
||||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
|
||||||
}
|
|
||||||
var upstreamErr *service.UpstreamDownloadError
|
|
||||||
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
|
|
||||||
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
|
|
||||||
if err != nil {
|
|
||||||
uploadedKeys = append(uploadedKeys, objectKey)
|
|
||||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
|
||||||
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
uploadedKeys = append(uploadedKeys, objectKey)
|
|
||||||
accessURLs = append(accessURLs, accessURL)
|
|
||||||
totalSize += fileSize
|
|
||||||
}
|
|
||||||
|
|
||||||
usageAdded := false
|
|
||||||
if totalSize > 0 && h.quotaService != nil {
|
|
||||||
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
|
|
||||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
|
||||||
var quotaErr *service.QuotaExceededError
|
|
||||||
if errors.As(err, "aErr) {
|
|
||||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
usageAdded = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.genService.UpdateStorageForCompleted(
|
|
||||||
c.Request.Context(),
|
|
||||||
id,
|
|
||||||
accessURLs[0],
|
|
||||||
accessURLs,
|
|
||||||
service.SoraStorageTypeS3,
|
|
||||||
uploadedKeys,
|
|
||||||
totalSize,
|
|
||||||
); err != nil {
|
|
||||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
|
||||||
if usageAdded && h.quotaService != nil {
|
|
||||||
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
|
|
||||||
}
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
|
||||||
"message": "已保存到 S3",
|
|
||||||
"object_key": uploadedKeys[0],
|
|
||||||
"object_keys": uploadedKeys,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStorageStatus 返回存储状态。
|
|
||||||
// GET /api/v1/sora/storage-status
|
|
||||||
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
|
|
||||||
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
|
|
||||||
s3Healthy := false
|
|
||||||
if s3Enabled {
|
|
||||||
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
|
|
||||||
}
|
|
||||||
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
|
|
||||||
response.Success(c, gin.H{
|
|
||||||
"s3_enabled": s3Enabled,
|
|
||||||
"s3_healthy": s3Healthy,
|
|
||||||
"local_enabled": localEnabled,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
|
|
||||||
switch storageType {
|
|
||||||
case service.SoraStorageTypeS3:
|
|
||||||
if h.s3Storage != nil && len(s3Keys) > 0 {
|
|
||||||
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case service.SoraStorageTypeLocal:
|
|
||||||
if h.mediaStorage != nil && len(localPaths) > 0 {
|
|
||||||
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
|
|
||||||
func getUserIDFromContext(c *gin.Context) int64 {
|
|
||||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
|
|
||||||
return subject.UserID
|
|
||||||
}
|
|
||||||
|
|
||||||
if id, ok := c.Get("user_id"); ok {
|
|
||||||
switch v := id.(type) {
|
|
||||||
case int64:
|
|
||||||
return v
|
|
||||||
case float64:
|
|
||||||
return int64(v)
|
|
||||||
case string:
|
|
||||||
n, _ := strconv.ParseInt(v, 10, 64)
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 尝试从 JWT claims 获取
|
|
||||||
if id, ok := c.Get("userID"); ok {
|
|
||||||
if v, ok := id.(int64); ok {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func groupIDForLog(groupID *int64) int64 {
|
|
||||||
if groupID == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return *groupID
|
|
||||||
}
|
|
||||||
|
|
||||||
func trimForLog(raw string, maxLen int) string {
|
|
||||||
trimmed := strings.TrimSpace(raw)
|
|
||||||
if maxLen <= 0 || len(trimmed) <= maxLen {
|
|
||||||
return trimmed
|
|
||||||
}
|
|
||||||
return trimmed[:maxLen] + "...(truncated)"
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModels 获取可用 Sora 模型家族列表。
|
|
||||||
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
|
|
||||||
// GET /api/v1/sora/models
|
|
||||||
func (h *SoraClientHandler) GetModels(c *gin.Context) {
|
|
||||||
families := h.getModelFamilies(c.Request.Context())
|
|
||||||
response.Success(c, families)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getModelFamilies 获取模型家族列表(带缓存)。
|
|
||||||
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
|
|
||||||
// 读锁检查缓存
|
|
||||||
h.modelCacheMu.RLock()
|
|
||||||
ttl := modelCacheTTL
|
|
||||||
if !h.modelCacheUpstream {
|
|
||||||
ttl = modelCacheFailedTTL
|
|
||||||
}
|
|
||||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
|
||||||
families := h.cachedFamilies
|
|
||||||
h.modelCacheMu.RUnlock()
|
|
||||||
return families
|
|
||||||
}
|
|
||||||
h.modelCacheMu.RUnlock()
|
|
||||||
|
|
||||||
// 写锁更新缓存
|
|
||||||
h.modelCacheMu.Lock()
|
|
||||||
defer h.modelCacheMu.Unlock()
|
|
||||||
|
|
||||||
// double-check
|
|
||||||
ttl = modelCacheTTL
|
|
||||||
if !h.modelCacheUpstream {
|
|
||||||
ttl = modelCacheFailedTTL
|
|
||||||
}
|
|
||||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
|
||||||
return h.cachedFamilies
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尝试从上游获取
|
|
||||||
families, err := h.fetchUpstreamModels(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
|
|
||||||
families = service.BuildSoraModelFamilies()
|
|
||||||
h.cachedFamilies = families
|
|
||||||
h.modelCacheTime = time.Now()
|
|
||||||
h.modelCacheUpstream = false
|
|
||||||
return families
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
|
|
||||||
h.cachedFamilies = families
|
|
||||||
h.modelCacheTime = time.Now()
|
|
||||||
h.modelCacheUpstream = true
|
|
||||||
return families
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
|
|
||||||
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
|
|
||||||
if h.gatewayService == nil {
|
|
||||||
return nil, fmt.Errorf("gatewayService 未初始化")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置 ForcePlatform 用于 Sora 账号选择
|
|
||||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
|
||||||
|
|
||||||
// 选择一个 Sora 账号
|
|
||||||
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 仅支持 API Key 类型账号
|
|
||||||
if account.Type != service.AccountTypeAPIKey {
|
|
||||||
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := account.GetCredential("api_key")
|
|
||||||
if apiKey == "" {
|
|
||||||
return nil, fmt.Errorf("账号缺少 api_key")
|
|
||||||
}
|
|
||||||
|
|
||||||
baseURL := account.GetBaseURL()
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, fmt.Errorf("账号缺少 base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建上游模型列表请求
|
|
||||||
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
|
|
||||||
|
|
||||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("请求上游失败: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 OpenAI 格式的模型列表
|
|
||||||
var modelsResp struct {
|
|
||||||
Data []struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modelsResp.Data) == 0 {
|
|
||||||
return nil, fmt.Errorf("上游返回空模型列表")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取模型 ID
|
|
||||||
modelIDs := make([]string, 0, len(modelsResp.Data))
|
|
||||||
for _, m := range modelsResp.Data {
|
|
||||||
modelIDs = append(modelIDs, m.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转换为模型家族
|
|
||||||
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
|
|
||||||
if len(families) == 0 {
|
|
||||||
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
|
|
||||||
}
|
|
||||||
|
|
||||||
return families, nil
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,697 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SoraGatewayHandler handles Sora chat completions requests
|
|
||||||
//
|
|
||||||
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
|
|
||||||
type SoraGatewayHandler struct {
|
|
||||||
gatewayService *service.GatewayService
|
|
||||||
soraGatewayService *service.SoraGatewayService
|
|
||||||
billingCacheService *service.BillingCacheService
|
|
||||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
|
||||||
concurrencyHelper *ConcurrencyHelper
|
|
||||||
maxAccountSwitches int
|
|
||||||
streamMode string
|
|
||||||
soraTLSEnabled bool
|
|
||||||
soraMediaSigningKey string
|
|
||||||
soraMediaRoot string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
|
||||||
func NewSoraGatewayHandler(
|
|
||||||
gatewayService *service.GatewayService,
|
|
||||||
soraGatewayService *service.SoraGatewayService,
|
|
||||||
concurrencyService *service.ConcurrencyService,
|
|
||||||
billingCacheService *service.BillingCacheService,
|
|
||||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
|
||||||
cfg *config.Config,
|
|
||||||
) *SoraGatewayHandler {
|
|
||||||
pingInterval := time.Duration(0)
|
|
||||||
maxAccountSwitches := 3
|
|
||||||
streamMode := "force"
|
|
||||||
soraTLSEnabled := true
|
|
||||||
signKey := ""
|
|
||||||
mediaRoot := "/app/data/sora"
|
|
||||||
if cfg != nil {
|
|
||||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
|
||||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
|
||||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
|
||||||
}
|
|
||||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
|
||||||
streamMode = mode
|
|
||||||
}
|
|
||||||
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
|
||||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
|
||||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
|
||||||
mediaRoot = root
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &SoraGatewayHandler{
|
|
||||||
gatewayService: gatewayService,
|
|
||||||
soraGatewayService: soraGatewayService,
|
|
||||||
billingCacheService: billingCacheService,
|
|
||||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
|
||||||
streamMode: strings.ToLower(streamMode),
|
|
||||||
soraTLSEnabled: soraTLSEnabled,
|
|
||||||
soraMediaSigningKey: signKey,
|
|
||||||
soraMediaRoot: mediaRoot,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletions handles Sora /v1/chat/completions endpoint
|
|
||||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|
||||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
|
||||||
if !ok {
|
|
||||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
|
||||||
if !ok {
|
|
||||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reqLog := requestLogger(
|
|
||||||
c,
|
|
||||||
"handler.sora_gateway.chat_completions",
|
|
||||||
zap.Int64("user_id", subject.UserID),
|
|
||||||
zap.Int64("api_key_id", apiKey.ID),
|
|
||||||
zap.Any("group_id", apiKey.GroupID),
|
|
||||||
)
|
|
||||||
|
|
||||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
|
||||||
if err != nil {
|
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(body) == 0 {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
|
||||||
|
|
||||||
// 校验请求体 JSON 合法性
|
|
||||||
if !gjson.ValidBytes(body) {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
|
||||||
modelResult := gjson.GetBytes(body, "model")
|
|
||||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reqModel := modelResult.String()
|
|
||||||
|
|
||||||
msgsResult := gjson.GetBytes(body, "messages")
|
|
||||||
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
|
||||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
|
|
||||||
if !clientStream {
|
|
||||||
if h.streamMode == "error" {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
body, err = sjson.SetBytes(body, "stream", true)
|
|
||||||
if err != nil {
|
|
||||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
|
|
||||||
|
|
||||||
platform := ""
|
|
||||||
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
|
||||||
platform = forced
|
|
||||||
} else if apiKey.Group != nil {
|
|
||||||
platform = apiKey.Group.Platform
|
|
||||||
}
|
|
||||||
if platform != service.PlatformSora {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
streamStarted := false
|
|
||||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
|
||||||
|
|
||||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
|
||||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
|
||||||
waitCounted := false
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
|
|
||||||
} else if !canWait {
|
|
||||||
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
|
||||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err == nil && canWait {
|
|
||||||
waitCounted = true
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if waitCounted {
|
|
||||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
|
|
||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if waitCounted {
|
|
||||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
|
||||||
waitCounted = false
|
|
||||||
}
|
|
||||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
|
||||||
if userReleaseFunc != nil {
|
|
||||||
defer userReleaseFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
|
||||||
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
|
|
||||||
status, code, message := billingErrorDetails(err)
|
|
||||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionHash := generateOpenAISessionHash(c, body)
|
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
lastFailoverStatus := 0
|
|
||||||
var lastFailoverBody []byte
|
|
||||||
var lastFailoverHeaders http.Header
|
|
||||||
|
|
||||||
for {
|
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("sora.account_select_failed",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
|
||||||
)
|
|
||||||
if len(failedAccountIDs) == 0 {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
|
||||||
fields := []zap.Field{
|
|
||||||
zap.Int("last_upstream_status", lastFailoverStatus),
|
|
||||||
}
|
|
||||||
if rayID != "" {
|
|
||||||
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
|
||||||
}
|
|
||||||
if mitigated != "" {
|
|
||||||
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
|
||||||
}
|
|
||||||
if contentType != "" {
|
|
||||||
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
|
||||||
}
|
|
||||||
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account := selection.Account
|
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
|
||||||
proxyBound := account.ProxyID != nil
|
|
||||||
proxyID := int64(0)
|
|
||||||
if account.ProxyID != nil {
|
|
||||||
proxyID = *account.ProxyID
|
|
||||||
}
|
|
||||||
tlsFingerprintEnabled := h.soraTLSEnabled
|
|
||||||
|
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
|
||||||
if !selection.Acquired {
|
|
||||||
if selection.WaitPlan == nil {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
accountWaitCounted := false
|
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
} else if !canWait {
|
|
||||||
reqLog.Info("sora.account_wait_queue_full",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
|
||||||
)
|
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err == nil && canWait {
|
|
||||||
accountWaitCounted = true
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
|
||||||
c,
|
|
||||||
account.ID,
|
|
||||||
selection.WaitPlan.MaxConcurrency,
|
|
||||||
selection.WaitPlan.Timeout,
|
|
||||||
clientStream,
|
|
||||||
&streamStarted,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("sora.account_slot_acquire_failed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
|
||||||
|
|
||||||
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
|
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
var failoverErr *service.UpstreamFailoverError
|
|
||||||
if errors.As(err, &failoverErr) {
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
|
||||||
lastFailoverBody = failoverErr.ResponseBody
|
|
||||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
|
||||||
fields := []zap.Field{
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
}
|
|
||||||
if rayID != "" {
|
|
||||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
|
||||||
}
|
|
||||||
if mitigated != "" {
|
|
||||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
|
||||||
}
|
|
||||||
if contentType != "" {
|
|
||||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
|
||||||
}
|
|
||||||
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
|
||||||
lastFailoverBody = failoverErr.ResponseBody
|
|
||||||
switchCount++
|
|
||||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
|
||||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
|
||||||
fields := []zap.Field{
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.String("upstream_error_code", upstreamErrCode),
|
|
||||||
zap.String("upstream_error_message", upstreamErrMsg),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
}
|
|
||||||
if rayID != "" {
|
|
||||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
|
||||||
}
|
|
||||||
if mitigated != "" {
|
|
||||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
|
||||||
}
|
|
||||||
if contentType != "" {
|
|
||||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
|
||||||
}
|
|
||||||
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reqLog.Error("sora.forward_failed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
|
||||||
clientIP := ip.GetClientIP(c)
|
|
||||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
|
||||||
inboundEndpoint := GetInboundEndpoint(c)
|
|
||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
APIKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: account,
|
|
||||||
Subscription: subscription,
|
|
||||||
InboundEndpoint: inboundEndpoint,
|
|
||||||
UpstreamEndpoint: upstreamEndpoint,
|
|
||||||
UserAgent: userAgent,
|
|
||||||
IPAddress: clientIP,
|
|
||||||
RequestPayloadHash: requestPayloadHash,
|
|
||||||
}); err != nil {
|
|
||||||
logger.L().With(
|
|
||||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
|
||||||
zap.Int64("user_id", subject.UserID),
|
|
||||||
zap.Int64("api_key_id", apiKey.ID),
|
|
||||||
zap.Any("group_id", apiKey.GroupID),
|
|
||||||
zap.String("model", reqModel),
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
).Error("sora.record_usage_failed", zap.Error(err))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
reqLog.Debug("sora.request_completed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int64("proxy_id", proxyID),
|
|
||||||
zap.Bool("proxy_bound", proxyBound),
|
|
||||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
|
|
||||||
if c == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
|
||||||
if sessionID == "" {
|
|
||||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
|
||||||
}
|
|
||||||
if sessionID == "" && len(body) > 0 {
|
|
||||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
|
||||||
}
|
|
||||||
if sessionID == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
hash := sha256.Sum256([]byte(sessionID))
|
|
||||||
return hex.EncodeToString(hash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
|
||||||
if task == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if h.usageRecordWorkerPool != nil {
|
|
||||||
h.usageRecordWorkerPool.Submit(task)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
defer func() {
|
|
||||||
if recovered := recover(); recovered != nil {
|
|
||||||
logger.L().With(
|
|
||||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
|
||||||
zap.Any("panic", recovered),
|
|
||||||
).Error("sora.usage_record_task_panic_recovered")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
task(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
|
||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
|
||||||
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
|
|
||||||
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
|
|
||||||
|
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
|
||||||
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
|
||||||
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
|
||||||
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
|
||||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
|
||||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
|
||||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
|
||||||
}
|
|
||||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
|
||||||
switch statusCode {
|
|
||||||
case 401, 403, 404, 500, 502, 503, 504:
|
|
||||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
|
||||||
case 429:
|
|
||||||
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch statusCode {
|
|
||||||
case 401:
|
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
|
||||||
case 403:
|
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
|
||||||
case 404:
|
|
||||||
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
|
||||||
}
|
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
|
||||||
case 429:
|
|
||||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
|
||||||
case 529:
|
|
||||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
|
||||||
case 500, 502, 503, 504:
|
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
|
||||||
default:
|
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneHTTPHeaders(headers http.Header) http.Header {
|
|
||||||
if headers == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return headers.Clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
|
||||||
if headers != nil {
|
|
||||||
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
|
||||||
contentType = strings.TrimSpace(headers.Get("content-type"))
|
|
||||||
if contentType == "" {
|
|
||||||
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
|
||||||
return rayID, mitigated, contentType
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
|
||||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
|
||||||
message = strings.TrimSpace(message)
|
|
||||||
if message == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
|
||||||
lower := strings.ToLower(message)
|
|
||||||
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
|
||||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
|
||||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
|
||||||
if streamStarted {
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if ok {
|
|
||||||
errorData := map[string]any{
|
|
||||||
"error": map[string]string{
|
|
||||||
"type": errType,
|
|
||||||
"message": message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(errorData)
|
|
||||||
if err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
|
||||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.errorResponse(c, status, errType, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
|
||||||
c.JSON(status, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": errType,
|
|
||||||
"message": message,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaProxy serves local Sora media files.
|
|
||||||
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
|
|
||||||
h.proxySoraMedia(c, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MediaProxySigned serves local Sora media files with signature verification.
|
|
||||||
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
|
|
||||||
h.proxySoraMedia(c, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
|
|
||||||
rawPath := c.Param("filepath")
|
|
||||||
if rawPath == "" {
|
|
||||||
c.Status(http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cleaned := path.Clean(rawPath)
|
|
||||||
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
|
||||||
c.Status(http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
query := c.Request.URL.Query()
|
|
||||||
if requireSignature {
|
|
||||||
if h.soraMediaSigningKey == "" {
|
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "api_error",
|
|
||||||
"message": "Sora 媒体签名未配置",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
expiresStr := strings.TrimSpace(query.Get("expires"))
|
|
||||||
signature := strings.TrimSpace(query.Get("sig"))
|
|
||||||
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
|
||||||
if err != nil || expires <= time.Now().Unix() {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "authentication_error",
|
|
||||||
"message": "Sora 媒体签名已过期",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
query.Del("sig")
|
|
||||||
query.Del("expires")
|
|
||||||
signingQuery := query.Encode()
|
|
||||||
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "authentication_error",
|
|
||||||
"message": "Sora 媒体签名无效",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(h.soraMediaRoot) == "" {
|
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "api_error",
|
|
||||||
"message": "Sora 媒体目录未配置",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
relative := strings.TrimPrefix(cleaned, "/")
|
|
||||||
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
|
|
||||||
if _, err := os.Stat(localPath); err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
c.Status(http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Status(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.File(localPath)
|
|
||||||
}
|
|
||||||
@ -1,728 +0,0 @@
|
|||||||
//go:build unit
|
|
||||||
|
|
||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/testutil"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 编译期接口断言
|
|
||||||
var _ service.SoraClient = (*stubSoraClient)(nil)
|
|
||||||
var _ service.AccountRepository = (*stubAccountRepo)(nil)
|
|
||||||
var _ service.GroupRepository = (*stubGroupRepo)(nil)
|
|
||||||
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
|
|
||||||
|
|
||||||
type stubSoraClient struct {
|
|
||||||
imageURLs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubSoraClient) Enabled() bool { return true }
|
|
||||||
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
|
|
||||||
return "upload", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
|
|
||||||
return "task-image", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
|
||||||
return "task-video", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
|
||||||
return "task-video", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
|
||||||
return "cameo-1", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
|
||||||
return &service.SoraCameoStatus{
|
|
||||||
Status: "finalized",
|
|
||||||
StatusMessage: "Completed",
|
|
||||||
DisplayNameHint: "Character",
|
|
||||||
UsernameHint: "user.character",
|
|
||||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
|
||||||
return []byte("avatar"), nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
|
||||||
return "asset-pointer", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
|
||||||
return "character-1", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
|
||||||
return "s_post", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
|
||||||
return "https://example.com/no-watermark.mp4", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
|
||||||
return "enhanced prompt", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
|
||||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
|
|
||||||
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubAccountRepo struct {
|
|
||||||
accounts map[int64]*service.Account
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
|
|
||||||
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
|
||||||
if acc, ok := r.accounts[id]; ok {
|
|
||||||
return acc, nil
|
|
||||||
}
|
|
||||||
return nil, service.ErrAccountNotFound
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
|
||||||
var result []*service.Account
|
|
||||||
for _, id := range ids {
|
|
||||||
if acc, ok := r.accounts[id]; ok {
|
|
||||||
result = append(result, acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
|
||||||
_, ok := r.accounts[id]
|
|
||||||
return ok, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
|
||||||
return map[string]int64{}, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
|
|
||||||
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
|
|
||||||
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
|
||||||
return r.listSchedulableByPlatform(platform), nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
|
|
||||||
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
|
||||||
return r.listSchedulable(), nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
|
||||||
return r.listSchedulable(), nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
|
||||||
return r.listSchedulableByPlatform(platform), nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
|
||||||
return r.listSchedulableByPlatform(platform), nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
|
||||||
var result []service.Account
|
|
||||||
for _, acc := range r.accounts {
|
|
||||||
for _, platform := range platforms {
|
|
||||||
if acc.Platform == platform && acc.IsSchedulable() {
|
|
||||||
result = append(result, *acc)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
|
||||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
|
||||||
return r.ListSchedulableByPlatform(ctx, platform)
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
|
||||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
|
||||||
var result []service.Account
|
|
||||||
for _, acc := range r.accounts {
|
|
||||||
if acc.IsSchedulable() {
|
|
||||||
result = append(result, *acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
|
|
||||||
var result []service.Account
|
|
||||||
for _, acc := range r.accounts {
|
|
||||||
if acc.Platform == platform && acc.IsSchedulable() {
|
|
||||||
result = append(result, *acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubGroupRepo struct {
|
|
||||||
group *service.Group
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
|
|
||||||
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
|
||||||
return r.group, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
|
||||||
return r.group, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
|
|
||||||
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
|
|
||||||
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
|
||||||
return 0, 0, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubUsageLogRepo struct{}
|
|
||||||
|
|
||||||
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
|
|
||||||
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
|
||||||
return []usagestats.EndpointStat{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
|
||||||
return []usagestats.EndpointStat{}, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
cfg := &config.Config{
|
|
||||||
RunMode: config.RunModeSimple,
|
|
||||||
Gateway: config.GatewayConfig{
|
|
||||||
SoraStreamMode: "force",
|
|
||||||
MaxAccountSwitches: 1,
|
|
||||||
Scheduling: config.GatewaySchedulingConfig{
|
|
||||||
LoadBatchEnabled: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
BaseURL: "https://sora.test",
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
|
|
||||||
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
|
|
||||||
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
|
|
||||||
groupRepo := &stubGroupRepo{group: group}
|
|
||||||
|
|
||||||
usageLogRepo := &stubUsageLogRepo{}
|
|
||||||
deferredService := service.NewDeferredService(accountRepo, nil, 0)
|
|
||||||
billingService := service.NewBillingService(cfg, nil)
|
|
||||||
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
|
|
||||||
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
billingCacheService.Stop()
|
|
||||||
})
|
|
||||||
|
|
||||||
gatewayService := service.NewGatewayService(
|
|
||||||
accountRepo,
|
|
||||||
groupRepo,
|
|
||||||
usageLogRepo,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
testutil.StubGatewayCache{},
|
|
||||||
cfg,
|
|
||||||
nil,
|
|
||||||
concurrencyService,
|
|
||||||
billingService,
|
|
||||||
nil,
|
|
||||||
billingCacheService,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
deferredService,
|
|
||||||
nil,
|
|
||||||
testutil.StubSessionLimitCache{},
|
|
||||||
nil, // rpmCache
|
|
||||||
nil, // digestStore
|
|
||||||
nil, // settingService
|
|
||||||
nil, // tlsFPProfileService
|
|
||||||
nil, // channelService
|
|
||||||
nil, // resolver
|
|
||||||
)
|
|
||||||
|
|
||||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
|
||||||
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
|
||||||
|
|
||||||
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(rec)
|
|
||||||
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
|
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
apiKey := &service.APIKey{
|
|
||||||
ID: 1,
|
|
||||||
UserID: 1,
|
|
||||||
Status: service.StatusActive,
|
|
||||||
GroupID: &group.ID,
|
|
||||||
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
|
|
||||||
Group: group,
|
|
||||||
}
|
|
||||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
|
||||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
|
|
||||||
|
|
||||||
handler.ChatCompletions(c)
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, rec.Code)
|
|
||||||
var resp map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
|
||||||
require.NotEmpty(t, resp["media_url"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
|
|
||||||
func TestSoraHandler_StreamForcing(t *testing.T) {
|
|
||||||
// 测试 1:stream=false 时 sjson 强制修改为 true
|
|
||||||
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
|
|
||||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
|
||||||
require.False(t, clientStream)
|
|
||||||
newBody, err := sjson.SetBytes(body, "stream", true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
|
|
||||||
|
|
||||||
// 测试 2:stream=true 时不修改
|
|
||||||
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
|
|
||||||
require.True(t, gjson.GetBytes(body2, "stream").Bool())
|
|
||||||
|
|
||||||
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
|
|
||||||
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
|
|
||||||
require.False(t, gjson.GetBytes(body3, "stream").Bool())
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
|
|
||||||
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
|
||||||
// model 缺失
|
|
||||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
|
||||||
modelResult := gjson.GetBytes(body, "model")
|
|
||||||
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
|
|
||||||
|
|
||||||
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
|
||||||
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
|
|
||||||
modelResult1b := gjson.GetBytes(body1b, "model")
|
|
||||||
require.True(t, modelResult1b.Exists())
|
|
||||||
require.NotEqual(t, gjson.String, modelResult1b.Type)
|
|
||||||
|
|
||||||
// messages 缺失
|
|
||||||
body2 := []byte(`{"model":"sora"}`)
|
|
||||||
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
|
|
||||||
|
|
||||||
// messages 不是 JSON 数组(字符串)
|
|
||||||
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
|
||||||
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
|
|
||||||
|
|
||||||
// messages 是对象而非数组 → IsArray 返回 false
|
|
||||||
body4 := []byte(`{"model":"sora","messages":{}}`)
|
|
||||||
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
|
|
||||||
|
|
||||||
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
|
|
||||||
body5 := []byte(`{"model":"sora","messages":[]}`)
|
|
||||||
msgsResult := gjson.GetBytes(body5, "messages")
|
|
||||||
require.True(t, msgsResult.IsArray())
|
|
||||||
require.Equal(t, 0, len(msgsResult.Array()))
|
|
||||||
|
|
||||||
// 非法 JSON 被 gjson.ValidBytes 拦截
|
|
||||||
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
|
||||||
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
// 从 body 提取 prompt_cache_key
|
|
||||||
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
|
||||||
|
|
||||||
hash := generateOpenAISessionHash(c, body)
|
|
||||||
require.NotEmpty(t, hash)
|
|
||||||
|
|
||||||
// 无 prompt_cache_key 且无 header → 空 hash
|
|
||||||
body2 := []byte(`{"model":"sora"}`)
|
|
||||||
hash2 := generateOpenAISessionHash(c, body2)
|
|
||||||
require.Empty(t, hash2)
|
|
||||||
|
|
||||||
// header 优先于 body
|
|
||||||
c.Request.Header.Set("session_id", "from-header")
|
|
||||||
hash3 := generateOpenAISessionHash(c, body)
|
|
||||||
require.NotEmpty(t, hash3)
|
|
||||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
errType string
|
|
||||||
message string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "包含双引号",
|
|
||||||
errType: "upstream_error",
|
|
||||||
message: `upstream returned "invalid" payload`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "包含换行和制表符",
|
|
||||||
errType: "rate_limit_error",
|
|
||||||
message: "line1\nline2\ttab",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "包含反斜杠",
|
|
||||||
errType: "upstream_error",
|
|
||||||
message: `path C:\Users\test\file.txt not found`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
|
||||||
|
|
||||||
body := w.Body.String()
|
|
||||||
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
|
||||||
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
|
||||||
|
|
||||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
|
||||||
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
|
||||||
require.Equal(t, "event: error", lines[0])
|
|
||||||
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
|
||||||
|
|
||||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
|
||||||
var parsed map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
|
||||||
|
|
||||||
errorObj, ok := parsed["error"].(map[string]any)
|
|
||||||
require.True(t, ok, "JSON 中应包含 error 对象")
|
|
||||||
require.Equal(t, tt.errType, errorObj["type"])
|
|
||||||
require.Equal(t, tt.message, errorObj["message"])
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
|
||||||
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
|
||||||
|
|
||||||
body := w.Body.String()
|
|
||||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
|
||||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
|
||||||
|
|
||||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
|
||||||
require.Len(t, lines, 2)
|
|
||||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
|
||||||
|
|
||||||
var parsed map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
|
||||||
|
|
||||||
errorObj, ok := parsed["error"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "upstream_error", errorObj["type"])
|
|
||||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
|
||||||
headers := http.Header{}
|
|
||||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
|
||||||
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
|
||||||
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
|
||||||
|
|
||||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
|
||||||
require.Len(t, lines, 2)
|
|
||||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
|
||||||
|
|
||||||
var parsed map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
|
||||||
|
|
||||||
errorObj, ok := parsed["error"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "upstream_error", errorObj["type"])
|
|
||||||
msg, _ := errorObj["message"].(string)
|
|
||||||
require.Contains(t, msg, "Cloudflare challenge")
|
|
||||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
|
||||||
headers := http.Header{}
|
|
||||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
|
||||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
|
||||||
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
|
||||||
|
|
||||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
|
||||||
require.Len(t, lines, 2)
|
|
||||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
|
||||||
|
|
||||||
var parsed map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
|
||||||
|
|
||||||
errorObj, ok := parsed["error"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
|
||||||
msg, _ := errorObj["message"].(string)
|
|
||||||
require.Contains(t, msg, "Cloudflare shield")
|
|
||||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
|
||||||
headers := http.Header{}
|
|
||||||
headers.Set("cf-mitigated", "challenge")
|
|
||||||
headers.Set("content-type", "text/html")
|
|
||||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
|
||||||
|
|
||||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
|
||||||
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
|
||||||
require.Equal(t, "challenge", mitigated)
|
|
||||||
require.Equal(t, "text/html", contentType)
|
|
||||||
}
|
|
||||||
@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
|||||||
})
|
})
|
||||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
|
||||||
pool := newUsageRecordTestPool(t)
|
|
||||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("task not executed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
var called atomic.Bool
|
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
||||||
if _, ok := ctx.Deadline(); !ok {
|
|
||||||
t.Fatal("expected deadline in fallback context")
|
|
||||||
}
|
|
||||||
called.Store(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
require.True(t, called.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
require.NotPanics(t, func() {
|
|
||||||
h.submitUsageRecordTask(nil)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
|
||||||
h := &SoraGatewayHandler{}
|
|
||||||
var called atomic.Bool
|
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
||||||
panic("usage task panic")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
||||||
called.Store(true)
|
|
||||||
})
|
|
||||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
|
||||||
}
|
|
||||||
|
|||||||
@ -86,8 +86,6 @@ func ProvideHandlers(
|
|||||||
adminHandlers *AdminHandlers,
|
adminHandlers *AdminHandlers,
|
||||||
gatewayHandler *GatewayHandler,
|
gatewayHandler *GatewayHandler,
|
||||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||||
soraGatewayHandler *SoraGatewayHandler,
|
|
||||||
soraClientHandler *SoraClientHandler,
|
|
||||||
settingHandler *SettingHandler,
|
settingHandler *SettingHandler,
|
||||||
totpHandler *TotpHandler,
|
totpHandler *TotpHandler,
|
||||||
_ *service.IdempotencyCoordinator,
|
_ *service.IdempotencyCoordinator,
|
||||||
@ -104,8 +102,6 @@ func ProvideHandlers(
|
|||||||
Admin: adminHandlers,
|
Admin: adminHandlers,
|
||||||
Gateway: gatewayHandler,
|
Gateway: gatewayHandler,
|
||||||
OpenAIGateway: openaiGatewayHandler,
|
OpenAIGateway: openaiGatewayHandler,
|
||||||
SoraGateway: soraGatewayHandler,
|
|
||||||
SoraClient: soraClientHandler,
|
|
||||||
Setting: settingHandler,
|
Setting: settingHandler,
|
||||||
Totp: totpHandler,
|
Totp: totpHandler,
|
||||||
}
|
}
|
||||||
@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAnnouncementHandler,
|
NewAnnouncementHandler,
|
||||||
NewGatewayHandler,
|
NewGatewayHandler,
|
||||||
NewOpenAIGatewayHandler,
|
NewOpenAIGatewayHandler,
|
||||||
NewSoraGatewayHandler,
|
|
||||||
NewTotpHandler,
|
NewTotpHandler,
|
||||||
ProvideSettingHandler,
|
ProvideSettingHandler,
|
||||||
|
|
||||||
|
|||||||
@ -17,8 +17,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
|
||||||
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
|
||||||
|
|
||||||
// OAuth endpoints
|
// OAuth endpoints
|
||||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||||
@ -39,8 +37,6 @@ const (
|
|||||||
const (
|
const (
|
||||||
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
|
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
|
||||||
OAuthPlatformOpenAI = "openai"
|
OAuthPlatformOpenAI = "openai"
|
||||||
// OAuthPlatformSora uses Sora OAuth client.
|
|
||||||
OAuthPlatformSora = "sora"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// OAuthSession stores OAuth flow state for OpenAI
|
// OAuthSession stores OAuth flow state for OpenAI
|
||||||
@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
|
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
|
||||||
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
|
|
||||||
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
|
|
||||||
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
|
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
|
||||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
return ClientID, true
|
||||||
case OAuthPlatformSora:
|
|
||||||
return ClientID, false
|
|
||||||
default:
|
|
||||||
return ClientID, true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenRequest represents the token exchange request body
|
// TokenRequest represents the token exchange request body
|
||||||
|
|||||||
@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
|
|||||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
|
|
||||||
// 但不启用 codex_cli_simplified_flow。
|
|
||||||
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
|
|
||||||
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
|
|
||||||
parsed, err := url.Parse(authURL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Parse URL failed: %v", err)
|
|
||||||
}
|
|
||||||
q := parsed.Query()
|
|
||||||
if got := q.Get("client_id"); got != ClientID {
|
|
||||||
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
|
|
||||||
}
|
|
||||||
if got := q.Get("codex_cli_simplified_flow"); got != "" {
|
|
||||||
t.Fatalf("codex flow should be empty for sora, got=%q", got)
|
|
||||||
}
|
|
||||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
|
||||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1692,20 +1692,13 @@ func itoa(v int) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
||||||
// 该方法限定 platform='sora',避免误查询其他平台的账号。
|
|
||||||
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
||||||
//
|
//
|
||||||
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
|
|
||||||
//
|
|
||||||
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
||||||
// Limited to platform='sora' to avoid querying accounts from other platforms.
|
|
||||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||||
//
|
|
||||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
|
||||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||||
accounts, err := r.client.Account.Query().
|
accounts, err := r.client.Account.Query().
|
||||||
Where(
|
Where(
|
||||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
|
||||||
dbaccount.DeletedAtIsNil(),
|
dbaccount.DeletedAtIsNil(),
|
||||||
func(s *entsql.Selector) {
|
func(s *entsql.Selector) {
|
||||||
path := sqljson.Path(key)
|
path := sqljson.Path(key)
|
||||||
|
|||||||
@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldImagePrice1k,
|
group.FieldImagePrice1k,
|
||||||
group.FieldImagePrice2k,
|
group.FieldImagePrice2k,
|
||||||
group.FieldImagePrice4k,
|
group.FieldImagePrice4k,
|
||||||
group.FieldSoraImagePrice360,
|
|
||||||
group.FieldSoraImagePrice540,
|
|
||||||
group.FieldSoraVideoPricePerRequest,
|
|
||||||
group.FieldSoraVideoPricePerRequestHd,
|
|
||||||
group.FieldClaudeCodeOnly,
|
group.FieldClaudeCodeOnly,
|
||||||
group.FieldFallbackGroupID,
|
group.FieldFallbackGroupID,
|
||||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||||
@ -608,22 +604,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.User{
|
return &service.User{
|
||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
PasswordHash: u.PasswordHash,
|
PasswordHash: u.PasswordHash,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
Status: u.Status,
|
Status: u.Status,
|
||||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
TotpEnabled: u.TotpEnabled,
|
||||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
TotpEnabledAt: u.TotpEnabledAt,
|
||||||
TotpEnabled: u.TotpEnabled,
|
CreatedAt: u.CreatedAt,
|
||||||
TotpEnabledAt: u.TotpEnabledAt,
|
UpdatedAt: u.UpdatedAt,
|
||||||
CreatedAt: u.CreatedAt,
|
|
||||||
UpdatedAt: u.UpdatedAt,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
ImagePrice1K: g.ImagePrice1k,
|
ImagePrice1K: g.ImagePrice1k,
|
||||||
ImagePrice2K: g.ImagePrice2k,
|
ImagePrice2K: g.ImagePrice2k,
|
||||||
ImagePrice4K: g.ImagePrice4k,
|
ImagePrice4K: g.ImagePrice4k,
|
||||||
SoraImagePrice360: g.SoraImagePrice360,
|
|
||||||
SoraImagePrice540: g.SoraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
|
||||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
|
||||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
|
||||||
DefaultValidityDays: g.DefaultValidityDays,
|
DefaultValidityDays: g.DefaultValidityDays,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
|
|||||||
@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
|
||||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
|
||||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
|
||||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
|
||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||||
@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
|
||||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
|
||||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
|
||||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
|
||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||||
|
|||||||
@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
|
|||||||
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
|
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
|
|
||||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
|
|
||||||
var seenClientIDs []string
|
|
||||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
clientID := r.PostForm.Get("client_id")
|
|
||||||
seenClientIDs = append(seenClientIDs, clientID)
|
|
||||||
if clientID == openai.SoraClientID {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
}))
|
|
||||||
|
|
||||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
|
|
||||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
|
||||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
|
||||||
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||||
const customClientID = "custom-client-id"
|
const customClientID = "custom-client-id"
|
||||||
var seenClientIDs []string
|
var seenClientIDs []string
|
||||||
@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
|
||||||
wantClientID := openai.SoraClientID
|
wantClientID := "custom-exchange-client-id"
|
||||||
errCh := make(chan string, 1)
|
errCh := make(chan string, 1)
|
||||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_ = r.ParseForm()
|
_ = r.ParseForm()
|
||||||
|
|||||||
@ -1,98 +0,0 @@
|
|||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
)
|
|
||||||
|
|
||||||
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
|
|
||||||
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
|
|
||||||
//
|
|
||||||
// 设计说明:
|
|
||||||
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
|
|
||||||
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
|
|
||||||
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
|
|
||||||
type soraAccountRepository struct {
|
|
||||||
sql *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
|
|
||||||
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
|
|
||||||
return &soraAccountRepository{sql: sqlDB}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsert 创建或更新 Sora 账号扩展信息
|
|
||||||
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
|
|
||||||
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
|
||||||
accessToken, accessOK := updates["access_token"].(string)
|
|
||||||
refreshToken, refreshOK := updates["refresh_token"].(string)
|
|
||||||
sessionToken, sessionOK := updates["session_token"].(string)
|
|
||||||
|
|
||||||
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
|
|
||||||
if !sessionOK {
|
|
||||||
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
|
|
||||||
}
|
|
||||||
result, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_accounts
|
|
||||||
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
|
|
||||||
updated_at = NOW()
|
|
||||||
WHERE account_id = $1
|
|
||||||
`, accountID, sessionToken)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rows, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if rows == 0 {
|
|
||||||
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := r.sql.ExecContext(ctx, `
|
|
||||||
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
|
|
||||||
VALUES ($1, $2, $3, $4, NOW(), NOW())
|
|
||||||
ON CONFLICT (account_id) DO UPDATE SET
|
|
||||||
access_token = EXCLUDED.access_token,
|
|
||||||
refresh_token = EXCLUDED.refresh_token,
|
|
||||||
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
|
|
||||||
updated_at = NOW()
|
|
||||||
`, accountID, accessToken, refreshToken, sessionToken)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
|
|
||||||
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
|
||||||
rows, err := r.sql.QueryContext(ctx, `
|
|
||||||
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
|
|
||||||
FROM sora_accounts
|
|
||||||
WHERE account_id = $1
|
|
||||||
`, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() { _ = rows.Close() }()
|
|
||||||
|
|
||||||
if !rows.Next() {
|
|
||||||
return nil, nil // 记录不存在
|
|
||||||
}
|
|
||||||
|
|
||||||
var sa service.SoraAccount
|
|
||||||
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &sa, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 删除 Sora 账号扩展信息
|
|
||||||
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
|
|
||||||
_, err := r.sql.ExecContext(ctx, `
|
|
||||||
DELETE FROM sora_accounts WHERE account_id = $1
|
|
||||||
`, accountID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@ -1,419 +0,0 @@
|
|||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
)
|
|
||||||
|
|
||||||
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
|
|
||||||
// 使用原生 SQL 操作 sora_generations 表。
|
|
||||||
type soraGenerationRepository struct {
|
|
||||||
sql *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
|
|
||||||
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
|
|
||||||
return &soraGenerationRepository{sql: sqlDB}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
|
|
||||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
|
||||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
|
||||||
|
|
||||||
err := r.sql.QueryRowContext(ctx, `
|
|
||||||
INSERT INTO sora_generations (
|
|
||||||
user_id, api_key_id, model, prompt, media_type,
|
|
||||||
status, media_url, media_urls, file_size_bytes,
|
|
||||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
|
||||||
RETURNING id, created_at
|
|
||||||
`,
|
|
||||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
|
||||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
|
||||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
|
||||||
).Scan(&gen.ID, &gen.CreatedAt)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
|
|
||||||
func (r *soraGenerationRepository) CreatePendingWithLimit(
|
|
||||||
ctx context.Context,
|
|
||||||
gen *service.SoraGeneration,
|
|
||||||
activeStatuses []string,
|
|
||||||
maxActive int64,
|
|
||||||
) error {
|
|
||||||
if gen == nil {
|
|
||||||
return fmt.Errorf("generation is nil")
|
|
||||||
}
|
|
||||||
if maxActive <= 0 {
|
|
||||||
return r.Create(ctx, gen)
|
|
||||||
}
|
|
||||||
if len(activeStatuses) == 0 {
|
|
||||||
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := r.sql.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
|
|
||||||
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
placeholders := make([]string, len(activeStatuses))
|
|
||||||
args := make([]any, 0, 1+len(activeStatuses))
|
|
||||||
args = append(args, gen.UserID)
|
|
||||||
for i, s := range activeStatuses {
|
|
||||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
|
||||||
args = append(args, s)
|
|
||||||
}
|
|
||||||
countQuery := fmt.Sprintf(
|
|
||||||
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
|
|
||||||
strings.Join(placeholders, ","),
|
|
||||||
)
|
|
||||||
var activeCount int64
|
|
||||||
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if activeCount >= maxActive {
|
|
||||||
return service.ErrSoraGenerationConcurrencyLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
|
||||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
|
||||||
if err := tx.QueryRowContext(ctx, `
|
|
||||||
INSERT INTO sora_generations (
|
|
||||||
user_id, api_key_id, model, prompt, media_type,
|
|
||||||
status, media_url, media_urls, file_size_bytes,
|
|
||||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
|
||||||
RETURNING id, created_at
|
|
||||||
`,
|
|
||||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
|
||||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
|
||||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
|
||||||
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
|
|
||||||
gen := &service.SoraGeneration{}
|
|
||||||
var mediaURLsJSON, s3KeysJSON []byte
|
|
||||||
var completedAt sql.NullTime
|
|
||||||
var apiKeyID sql.NullInt64
|
|
||||||
|
|
||||||
err := r.sql.QueryRowContext(ctx, `
|
|
||||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
|
||||||
status, media_url, media_urls, file_size_bytes,
|
|
||||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
|
||||||
created_at, completed_at
|
|
||||||
FROM sora_generations WHERE id = $1
|
|
||||||
`, id).Scan(
|
|
||||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
|
||||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
|
||||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
|
||||||
&gen.CreatedAt, &completedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, fmt.Errorf("生成记录不存在")
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if apiKeyID.Valid {
|
|
||||||
gen.APIKeyID = &apiKeyID.Int64
|
|
||||||
}
|
|
||||||
if completedAt.Valid {
|
|
||||||
gen.CompletedAt = &completedAt.Time
|
|
||||||
}
|
|
||||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
|
||||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
|
||||||
return gen, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
|
|
||||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
|
||||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
|
||||||
|
|
||||||
var completedAt *time.Time
|
|
||||||
if gen.CompletedAt != nil {
|
|
||||||
completedAt = gen.CompletedAt
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_generations SET
|
|
||||||
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
|
|
||||||
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
|
|
||||||
error_message = $9, completed_at = $10
|
|
||||||
WHERE id = $1
|
|
||||||
`,
|
|
||||||
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
|
||||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
|
|
||||||
gen.ErrorMessage, completedAt,
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
|
|
||||||
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
|
|
||||||
result, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_generations
|
|
||||||
SET status = $2, upstream_task_id = $3
|
|
||||||
WHERE id = $1 AND status = $4
|
|
||||||
`,
|
|
||||||
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return affected > 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
|
|
||||||
func (r *soraGenerationRepository) UpdateCompletedIfActive(
|
|
||||||
ctx context.Context,
|
|
||||||
id int64,
|
|
||||||
mediaURL string,
|
|
||||||
mediaURLs []string,
|
|
||||||
storageType string,
|
|
||||||
s3Keys []string,
|
|
||||||
fileSizeBytes int64,
|
|
||||||
completedAt time.Time,
|
|
||||||
) (bool, error) {
|
|
||||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
|
||||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
|
||||||
result, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_generations
|
|
||||||
SET status = $2,
|
|
||||||
media_url = $3,
|
|
||||||
media_urls = $4,
|
|
||||||
file_size_bytes = $5,
|
|
||||||
storage_type = $6,
|
|
||||||
s3_object_keys = $7,
|
|
||||||
error_message = '',
|
|
||||||
completed_at = $8
|
|
||||||
WHERE id = $1 AND status IN ($9, $10)
|
|
||||||
`,
|
|
||||||
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
|
|
||||||
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return affected > 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
|
|
||||||
func (r *soraGenerationRepository) UpdateFailedIfActive(
|
|
||||||
ctx context.Context,
|
|
||||||
id int64,
|
|
||||||
errMsg string,
|
|
||||||
completedAt time.Time,
|
|
||||||
) (bool, error) {
|
|
||||||
result, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_generations
|
|
||||||
SET status = $2,
|
|
||||||
error_message = $3,
|
|
||||||
completed_at = $4
|
|
||||||
WHERE id = $1 AND status IN ($5, $6)
|
|
||||||
`,
|
|
||||||
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return affected > 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
|
|
||||||
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
|
|
||||||
result, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_generations
|
|
||||||
SET status = $2, completed_at = $3
|
|
||||||
WHERE id = $1 AND status IN ($4, $5)
|
|
||||||
`,
|
|
||||||
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return affected > 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
|
|
||||||
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
|
|
||||||
ctx context.Context,
|
|
||||||
id int64,
|
|
||||||
mediaURL string,
|
|
||||||
mediaURLs []string,
|
|
||||||
storageType string,
|
|
||||||
s3Keys []string,
|
|
||||||
fileSizeBytes int64,
|
|
||||||
) (bool, error) {
|
|
||||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
|
||||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
|
||||||
result, err := r.sql.ExecContext(ctx, `
|
|
||||||
UPDATE sora_generations
|
|
||||||
SET media_url = $2,
|
|
||||||
media_urls = $3,
|
|
||||||
file_size_bytes = $4,
|
|
||||||
storage_type = $5,
|
|
||||||
s3_object_keys = $6
|
|
||||||
WHERE id = $1 AND status = $7
|
|
||||||
`,
|
|
||||||
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return affected > 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
|
|
||||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
|
|
||||||
// 构建 WHERE 条件
|
|
||||||
conditions := []string{"user_id = $1"}
|
|
||||||
args := []any{params.UserID}
|
|
||||||
argIdx := 2
|
|
||||||
|
|
||||||
if params.Status != "" {
|
|
||||||
// 支持逗号分隔的多状态
|
|
||||||
statuses := strings.Split(params.Status, ",")
|
|
||||||
placeholders := make([]string, len(statuses))
|
|
||||||
for i, s := range statuses {
|
|
||||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
|
||||||
args = append(args, strings.TrimSpace(s))
|
|
||||||
argIdx++
|
|
||||||
}
|
|
||||||
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
|
|
||||||
}
|
|
||||||
if params.StorageType != "" {
|
|
||||||
storageTypes := strings.Split(params.StorageType, ",")
|
|
||||||
placeholders := make([]string, len(storageTypes))
|
|
||||||
for i, s := range storageTypes {
|
|
||||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
|
||||||
args = append(args, strings.TrimSpace(s))
|
|
||||||
argIdx++
|
|
||||||
}
|
|
||||||
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
|
|
||||||
}
|
|
||||||
if params.MediaType != "" {
|
|
||||||
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
|
|
||||||
args = append(args, params.MediaType)
|
|
||||||
argIdx++
|
|
||||||
}
|
|
||||||
|
|
||||||
whereClause := "WHERE " + strings.Join(conditions, " AND ")
|
|
||||||
|
|
||||||
// 计数
|
|
||||||
var total int64
|
|
||||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
|
|
||||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分页查询
|
|
||||||
offset := (params.Page - 1) * params.PageSize
|
|
||||||
listQuery := fmt.Sprintf(`
|
|
||||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
|
||||||
status, media_url, media_urls, file_size_bytes,
|
|
||||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
|
||||||
created_at, completed_at
|
|
||||||
FROM sora_generations %s
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT $%d OFFSET $%d
|
|
||||||
`, whereClause, argIdx, argIdx+1)
|
|
||||||
args = append(args, params.PageSize, offset)
|
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = rows.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var results []*service.SoraGeneration
|
|
||||||
for rows.Next() {
|
|
||||||
gen := &service.SoraGeneration{}
|
|
||||||
var mediaURLsJSON, s3KeysJSON []byte
|
|
||||||
var completedAt sql.NullTime
|
|
||||||
var apiKeyID sql.NullInt64
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
|
||||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
|
||||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
|
||||||
&gen.CreatedAt, &completedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if apiKeyID.Valid {
|
|
||||||
gen.APIKeyID = &apiKeyID.Int64
|
|
||||||
}
|
|
||||||
if completedAt.Valid {
|
|
||||||
gen.CompletedAt = &completedAt.Time
|
|
||||||
}
|
|
||||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
|
||||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
|
||||||
results = append(results, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, total, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
|
|
||||||
if len(statuses) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
placeholders := make([]string, len(statuses))
|
|
||||||
args := []any{userID}
|
|
||||||
for i, s := range statuses {
|
|
||||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
|
||||||
args = append(args, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
var count int64
|
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
|
|
||||||
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
@ -28,7 +28,7 @@ import (
|
|||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
||||||
|
|
||||||
// usageLogInsertArgTypes must stay in the same order as:
|
// usageLogInsertArgTypes must stay in the same order as:
|
||||||
// 1. prepareUsageLogInsert().args
|
// 1. prepareUsageLogInsert().args
|
||||||
@ -73,7 +73,6 @@ var usageLogInsertArgTypes = [...]string{
|
|||||||
"text", // ip_address
|
"text", // ip_address
|
||||||
"integer", // image_count
|
"integer", // image_count
|
||||||
"text", // image_size
|
"text", // image_size
|
||||||
"text", // media_type
|
|
||||||
"text", // service_tier
|
"text", // service_tier
|
||||||
"text", // reasoning_effort
|
"text", // reasoning_effort
|
||||||
"text", // inbound_endpoint
|
"text", // inbound_endpoint
|
||||||
@ -352,7 +351,6 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -369,7 +367,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
$10, $11, $12, $13,
|
$10, $11, $12, $13,
|
||||||
$14, $15, $16, $17,
|
$14, $15, $16, $17,
|
||||||
$18, $19, $20, $21, $22, $23,
|
$18, $19, $20, $21, $22, $23,
|
||||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@ -790,7 +788,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -803,7 +800,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(keys)*47)
|
args := make([]any, 0, len(keys)*46)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, key := range keys {
|
for idx, key := range keys {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@ -867,7 +864,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -915,7 +911,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -1003,7 +998,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -1016,7 +1010,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(preparedList)*46)
|
args := make([]any, 0, len(preparedList)*45)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, prepared := range preparedList {
|
for idx, prepared := range preparedList {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@ -1077,7 +1071,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -1125,7 +1118,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -1181,7 +1173,6 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
ip_address,
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
media_type,
|
|
||||||
service_tier,
|
service_tier,
|
||||||
reasoning_effort,
|
reasoning_effort,
|
||||||
inbound_endpoint,
|
inbound_endpoint,
|
||||||
@ -1198,7 +1189,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
$10, $11, $12, $13,
|
$10, $11, $12, $13,
|
||||||
$14, $15, $16, $17,
|
$14, $15, $16, $17,
|
||||||
$18, $19, $20, $21, $22, $23,
|
$18, $19, $20, $21, $22, $23,
|
||||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
`, prepared.args...)
|
`, prepared.args...)
|
||||||
@ -1225,7 +1216,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
userAgent := nullString(log.UserAgent)
|
userAgent := nullString(log.UserAgent)
|
||||||
ipAddress := nullString(log.IPAddress)
|
ipAddress := nullString(log.IPAddress)
|
||||||
imageSize := nullString(log.ImageSize)
|
imageSize := nullString(log.ImageSize)
|
||||||
mediaType := nullString(log.MediaType)
|
|
||||||
serviceTier := nullString(log.ServiceTier)
|
serviceTier := nullString(log.ServiceTier)
|
||||||
reasoningEffort := nullString(log.ReasoningEffort)
|
reasoningEffort := nullString(log.ReasoningEffort)
|
||||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||||
@ -1286,7 +1276,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
ipAddress,
|
ipAddress,
|
||||||
log.ImageCount,
|
log.ImageCount,
|
||||||
imageSize,
|
imageSize,
|
||||||
mediaType,
|
|
||||||
serviceTier,
|
serviceTier,
|
||||||
reasoningEffort,
|
reasoningEffort,
|
||||||
inboundEndpoint,
|
inboundEndpoint,
|
||||||
@ -4051,7 +4040,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
ipAddress sql.NullString
|
ipAddress sql.NullString
|
||||||
imageCount int
|
imageCount int
|
||||||
imageSize sql.NullString
|
imageSize sql.NullString
|
||||||
mediaType sql.NullString
|
|
||||||
serviceTier sql.NullString
|
serviceTier sql.NullString
|
||||||
reasoningEffort sql.NullString
|
reasoningEffort sql.NullString
|
||||||
inboundEndpoint sql.NullString
|
inboundEndpoint sql.NullString
|
||||||
@ -4101,7 +4089,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&ipAddress,
|
&ipAddress,
|
||||||
&imageCount,
|
&imageCount,
|
||||||
&imageSize,
|
&imageSize,
|
||||||
&mediaType,
|
|
||||||
&serviceTier,
|
&serviceTier,
|
||||||
&reasoningEffort,
|
&reasoningEffort,
|
||||||
&inboundEndpoint,
|
&inboundEndpoint,
|
||||||
@ -4179,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if imageSize.Valid {
|
if imageSize.Valid {
|
||||||
log.ImageSize = &imageSize.String
|
log.ImageSize = &imageSize.String
|
||||||
}
|
}
|
||||||
if mediaType.Valid {
|
|
||||||
log.MediaType = &mediaType.String
|
|
||||||
}
|
|
||||||
if serviceTier.Valid {
|
if serviceTier.Valid {
|
||||||
log.ServiceTier = &serviceTier.String
|
log.ServiceTier = &serviceTier.String
|
||||||
}
|
}
|
||||||
|
|||||||
@ -76,7 +76,6 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // ip_address
|
sqlmock.AnyArg(), // ip_address
|
||||||
log.ImageCount,
|
log.ImageCount,
|
||||||
sqlmock.AnyArg(), // image_size
|
sqlmock.AnyArg(), // image_size
|
||||||
sqlmock.AnyArg(), // media_type
|
|
||||||
sqlmock.AnyArg(), // service_tier
|
sqlmock.AnyArg(), // service_tier
|
||||||
sqlmock.AnyArg(), // reasoning_effort
|
sqlmock.AnyArg(), // reasoning_effort
|
||||||
sqlmock.AnyArg(), // inbound_endpoint
|
sqlmock.AnyArg(), // inbound_endpoint
|
||||||
@ -155,7 +154,6 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
|||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
log.ImageCount,
|
log.ImageCount,
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
sqlmock.AnyArg(),
|
|
||||||
serviceTier,
|
serviceTier,
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
@ -471,7 +469,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
0,
|
0,
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
|
||||||
sql.NullString{Valid: true, String: "priority"},
|
sql.NullString{Valid: true, String: "priority"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
@ -519,7 +516,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
0,
|
0,
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
|
||||||
sql.NullString{Valid: true, String: "flex"},
|
sql.NullString{Valid: true, String: "flex"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
@ -567,7 +563,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
0,
|
0,
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
|
||||||
sql.NullString{Valid: true, String: "priority"},
|
sql.NullString{Valid: true, String: "priority"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
|||||||
@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
SetBalance(userIn.Balance).
|
SetBalance(userIn.Balance).
|
||||||
SetConcurrency(userIn.Concurrency).
|
SetConcurrency(userIn.Concurrency).
|
||||||
SetStatus(userIn.Status).
|
SetStatus(userIn.Status).
|
||||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||||
@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
SetBalance(userIn.Balance).
|
SetBalance(userIn.Balance).
|
||||||
SetConcurrency(userIn.Concurrency).
|
SetConcurrency(userIn.Concurrency).
|
||||||
SetStatus(userIn.Status).
|
SetStatus(userIn.Status).
|
||||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
|
||||||
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
|
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||||
@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
|
|
||||||
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
|
|
||||||
if deltaBytes <= 0 {
|
|
||||||
user, err := r.GetByID(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return user.SoraStorageUsedBytes, nil
|
|
||||||
}
|
|
||||||
var newUsed int64
|
|
||||||
err := scanSingleRow(ctx, r.sql, `
|
|
||||||
UPDATE users
|
|
||||||
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
|
|
||||||
WHERE id = $1
|
|
||||||
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
|
|
||||||
RETURNING sora_storage_used_bytes
|
|
||||||
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
|
|
||||||
if err == nil {
|
|
||||||
return newUsed, nil
|
|
||||||
}
|
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
// 区分用户不存在和配额冲突
|
|
||||||
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
|
|
||||||
if existsErr != nil {
|
|
||||||
return 0, existsErr
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return 0, service.ErrUserNotFound
|
|
||||||
}
|
|
||||||
return 0, service.ErrSoraStorageQuotaExceeded
|
|
||||||
}
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
|
|
||||||
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
|
|
||||||
if deltaBytes <= 0 {
|
|
||||||
user, err := r.GetByID(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return user.SoraStorageUsedBytes, nil
|
|
||||||
}
|
|
||||||
var newUsed int64
|
|
||||||
err := scanSingleRow(ctx, r.sql, `
|
|
||||||
UPDATE users
|
|
||||||
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
|
|
||||||
WHERE id = $1
|
|
||||||
RETURNING sora_storage_used_bytes
|
|
||||||
`, []any{userID, deltaBytes}, &newUsed)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
return 0, service.ErrUserNotFound
|
|
||||||
}
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return newUsed, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAPIKeyRepository,
|
NewAPIKeyRepository,
|
||||||
NewGroupRepository,
|
NewGroupRepository,
|
||||||
NewAccountRepository,
|
NewAccountRepository,
|
||||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
|
||||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||||
NewProxyRepository,
|
NewProxyRepository,
|
||||||
|
|||||||
@ -204,11 +204,6 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"image_price_1k": null,
|
"image_price_1k": null,
|
||||||
"image_price_2k": null,
|
"image_price_2k": null,
|
||||||
"image_price_4k": null,
|
"image_price_4k": null,
|
||||||
"sora_image_price_360": null,
|
|
||||||
"sora_image_price_540": null,
|
|
||||||
"sora_storage_quota_bytes": 0,
|
|
||||||
"sora_video_price_per_request": null,
|
|
||||||
"sora_video_price_per_request_hd": null,
|
|
||||||
"claude_code_only": false,
|
"claude_code_only": false,
|
||||||
"allow_messages_dispatch": false,
|
"allow_messages_dispatch": false,
|
||||||
"fallback_group_id": null,
|
"fallback_group_id": null,
|
||||||
@ -532,7 +527,6 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"fallback_model_openai": "gpt-4o",
|
"fallback_model_openai": "gpt-4o",
|
||||||
"enable_identity_patch": true,
|
"enable_identity_patch": true,
|
||||||
"identity_patch_prompt": "",
|
"identity_patch_prompt": "",
|
||||||
"sora_client_enabled": false,
|
|
||||||
"invitation_code_enabled": false,
|
"invitation_code_enabled": false,
|
||||||
"home_content": "",
|
"home_content": "",
|
||||||
"hide_ccs_import_button": false,
|
"hide_ccs_import_button": false,
|
||||||
@ -653,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
|
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
jwtAuth := func(c *gin.Context) {
|
jwtAuth := func(c *gin.Context) {
|
||||||
|
|||||||
@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
|
|||||||
return strings.HasPrefix(path, "/v1/") ||
|
return strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
strings.HasPrefix(path, "/antigravity/") ||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/sora/") ||
|
|
||||||
strings.HasPrefix(path, "/responses")
|
strings.HasPrefix(path, "/responses")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -109,7 +109,6 @@ func registerRoutes(
|
|||||||
// 注册各模块路由
|
// 注册各模块路由
|
||||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
|
||||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,8 +34,6 @@ func RegisterAdminRoutes(
|
|||||||
|
|
||||||
// OpenAI OAuth
|
// OpenAI OAuth
|
||||||
registerOpenAIOAuthRoutes(admin, h)
|
registerOpenAIOAuthRoutes(admin, h)
|
||||||
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
|
|
||||||
registerSoraOAuthRoutes(admin, h)
|
|
||||||
|
|
||||||
// Gemini OAuth
|
// Gemini OAuth
|
||||||
registerGeminiOAuthRoutes(admin, h)
|
registerGeminiOAuthRoutes(admin, h)
|
||||||
@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|
||||||
sora := admin.Group("/sora")
|
|
||||||
{
|
|
||||||
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
|
||||||
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
|
||||||
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
|
||||||
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
|
|
||||||
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
|
|
||||||
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
|
||||||
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
gemini := admin.Group("/gemini")
|
gemini := admin.Group("/gemini")
|
||||||
{
|
{
|
||||||
@ -422,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
// Beta 策略配置
|
// Beta 策略配置
|
||||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||||
// Sora S3 存储配置
|
|
||||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
|
||||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
|
||||||
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
|
|
||||||
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
|
|
||||||
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
|
|
||||||
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
|
|
||||||
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
|
|
||||||
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
|
|||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) {
|
) {
|
||||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||||
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
|
|
||||||
if soraMaxBodySize <= 0 {
|
|
||||||
soraMaxBodySize = cfg.Gateway.MaxBodySize
|
|
||||||
}
|
|
||||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
|
||||||
clientRequestID := middleware.ClientRequestID()
|
clientRequestID := middleware.ClientRequestID()
|
||||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||||
endpointNorm := handler.InboundEndpointMiddleware()
|
endpointNorm := handler.InboundEndpointMiddleware()
|
||||||
@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
|
|||||||
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
|
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sora 专用路由(强制使用 sora 平台)
|
|
||||||
soraV1 := r.Group("/sora/v1")
|
|
||||||
soraV1.Use(soraBodyLimit)
|
|
||||||
soraV1.Use(clientRequestID)
|
|
||||||
soraV1.Use(opsErrorLogger)
|
|
||||||
soraV1.Use(endpointNorm)
|
|
||||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
|
||||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
|
||||||
soraV1.Use(requireGroupAnthropic)
|
|
||||||
{
|
|
||||||
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
|
||||||
soraV1.GET("/models", h.Gateway.Models)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sora 媒体代理(可选 API Key 验证)
|
|
||||||
if cfg.Gateway.SoraMediaRequireAPIKey {
|
|
||||||
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
|
|
||||||
} else {
|
|
||||||
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
|
|
||||||
}
|
|
||||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
|
||||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
||||||
|
|||||||
@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine {
|
|||||||
&handler.Handlers{
|
&handler.Handlers{
|
||||||
Gateway: &handler.GatewayHandler{},
|
Gateway: &handler.GatewayHandler{},
|
||||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||||
SoraGateway: &handler.SoraGatewayHandler{},
|
|
||||||
},
|
},
|
||||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@ -1,36 +0,0 @@
|
|||||||
package routes
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
|
|
||||||
func RegisterSoraClientRoutes(
|
|
||||||
v1 *gin.RouterGroup,
|
|
||||||
h *handler.Handlers,
|
|
||||||
jwtAuth middleware.JWTAuthMiddleware,
|
|
||||||
settingService *service.SettingService,
|
|
||||||
) {
|
|
||||||
if h.SoraClient == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticated := v1.Group("/sora")
|
|
||||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
|
||||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
|
||||||
{
|
|
||||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
|
||||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
|
||||||
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
|
|
||||||
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
|
|
||||||
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
|
|
||||||
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
|
|
||||||
authenticated.GET("/quota", h.SoraClient.GetQuota)
|
|
||||||
authenticated.GET("/models", h.SoraClient.GetModels)
|
|
||||||
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -28,8 +28,7 @@ type AccountRepository interface {
|
|||||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||||
// Returns (nil, nil) if not found.
|
// Returns (nil, nil) if not found.
|
||||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||||
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
// FindByExtraField 根据 extra 字段中的键值对查找账号
|
||||||
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
|
||||||
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||||
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||||
// for all accounts that have been synced from CRS.
|
// for all accounts that have been synced from CRS.
|
||||||
|
|||||||
@ -13,18 +13,14 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
|||||||
const (
|
const (
|
||||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
|
||||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
|
||||||
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
|
|
||||||
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
|
|
||||||
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestEvent represents a SSE event for account testing
|
// TestEvent represents a SSE event for account testing
|
||||||
@ -71,13 +62,8 @@ type AccountTestService struct {
|
|||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
tlsFPProfileService *TLSFingerprintProfileService
|
tlsFPProfileService *TLSFingerprintProfileService
|
||||||
soraTestGuardMu sync.Mutex
|
|
||||||
soraTestLastRun map[int64]time.Time
|
|
||||||
soraTestCooldown time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultSoraTestCooldown = 10 * time.Second
|
|
||||||
|
|
||||||
// NewAccountTestService creates a new AccountTestService
|
// NewAccountTestService creates a new AccountTestService
|
||||||
func NewAccountTestService(
|
func NewAccountTestService(
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
@ -94,8 +80,6 @@ func NewAccountTestService(
|
|||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
tlsFPProfileService: tlsFPProfileService,
|
tlsFPProfileService: tlsFPProfileService,
|
||||||
soraTestLastRun: make(map[int64]time.Time),
|
|
||||||
soraTestCooldown: defaultSoraTestCooldown,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Platform == PlatformSora {
|
|
||||||
return s.testSoraAccountConnection(c, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.testClaudeAccountConnection(c, account, modelID)
|
return s.testClaudeAccountConnection(c, account, modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -634,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
|||||||
return s.processGeminiStream(c, resp.Body)
|
return s.processGeminiStream(c, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
type soraProbeStep struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
HTTPStatus int `json:"http_status,omitempty"`
|
|
||||||
ErrorCode string `json:"error_code,omitempty"`
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type soraProbeSummary struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
Steps []soraProbeStep `json:"steps"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type soraProbeRecorder struct {
|
|
||||||
steps []soraProbeStep
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
|
|
||||||
r.steps = append(r.steps, soraProbeStep{
|
|
||||||
Name: name,
|
|
||||||
Status: status,
|
|
||||||
HTTPStatus: httpStatus,
|
|
||||||
ErrorCode: strings.TrimSpace(errorCode),
|
|
||||||
Message: strings.TrimSpace(message),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *soraProbeRecorder) finalize() soraProbeSummary {
|
|
||||||
meSuccess := false
|
|
||||||
partial := false
|
|
||||||
for _, step := range r.steps {
|
|
||||||
if step.Name == "me" {
|
|
||||||
meSuccess = strings.EqualFold(step.Status, "success")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.EqualFold(step.Status, "failed") {
|
|
||||||
partial = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status := "success"
|
|
||||||
if !meSuccess {
|
|
||||||
status = "failed"
|
|
||||||
} else if partial {
|
|
||||||
status = "partial_success"
|
|
||||||
}
|
|
||||||
|
|
||||||
return soraProbeSummary{
|
|
||||||
Status: status,
|
|
||||||
Steps: append([]soraProbeStep(nil), r.steps...),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
|
|
||||||
if rec == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
summary := rec.finalize()
|
|
||||||
code := ""
|
|
||||||
for _, step := range summary.Steps {
|
|
||||||
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
|
|
||||||
code = step.ErrorCode
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{
|
|
||||||
Type: "sora_test_result",
|
|
||||||
Status: summary.Status,
|
|
||||||
Code: code,
|
|
||||||
Data: summary,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
|
|
||||||
if accountID <= 0 {
|
|
||||||
return 0, true
|
|
||||||
}
|
|
||||||
s.soraTestGuardMu.Lock()
|
|
||||||
defer s.soraTestGuardMu.Unlock()
|
|
||||||
|
|
||||||
if s.soraTestLastRun == nil {
|
|
||||||
s.soraTestLastRun = make(map[int64]time.Time)
|
|
||||||
}
|
|
||||||
cooldown := s.soraTestCooldown
|
|
||||||
if cooldown <= 0 {
|
|
||||||
cooldown = defaultSoraTestCooldown
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
|
|
||||||
elapsed := now.Sub(lastRun)
|
|
||||||
if elapsed < cooldown {
|
|
||||||
return cooldown - elapsed, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.soraTestLastRun[accountID] = now
|
|
||||||
return 0, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func ceilSeconds(d time.Duration) int {
|
|
||||||
if d <= 0 {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
sec := int(d / time.Second)
|
|
||||||
if d%time.Second != 0 {
|
|
||||||
sec++
|
|
||||||
}
|
|
||||||
if sec < 1 {
|
|
||||||
sec = 1
|
|
||||||
}
|
|
||||||
return sec
|
|
||||||
}
|
|
||||||
|
|
||||||
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
|
|
||||||
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
|
|
||||||
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
|
|
||||||
ctx := c.Request.Context()
|
|
||||||
|
|
||||||
apiKey := account.GetCredential("api_key")
|
|
||||||
if apiKey == "" {
|
|
||||||
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
|
|
||||||
}
|
|
||||||
|
|
||||||
baseURL := account.GetBaseURL()
|
|
||||||
if baseURL == "" {
|
|
||||||
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 base_url 格式
|
|
||||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
|
|
||||||
}
|
|
||||||
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
|
|
||||||
|
|
||||||
// 设置 SSE 头
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
||||||
c.Writer.Header().Set("Connection", "keep-alive")
|
|
||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
c.Writer.Flush()
|
|
||||||
|
|
||||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
|
||||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
|
||||||
return s.sendErrorAndEnd(c, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
|
|
||||||
|
|
||||||
// 构建轻量级 prompt-enhance 请求作为连通性测试
|
|
||||||
testPayload := map[string]any{
|
|
||||||
"model": "prompt-enhance-short-10s",
|
|
||||||
"messages": []map[string]string{{"role": "user", "content": "test"}},
|
|
||||||
"stream": false,
|
|
||||||
}
|
|
||||||
payloadBytes, _ := json.Marshal(testPayload)
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, "构建测试请求失败")
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
|
|
||||||
// 获取代理 URL
|
|
||||||
proxyURL := ""
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
|
|
||||||
if resp.StatusCode == http.StatusBadRequest {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
|
||||||
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
|
|
||||||
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
|
|
||||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
|
||||||
// apikey 类型走独立测试流程
|
|
||||||
if account.Type == AccountTypeAPIKey {
|
|
||||||
return s.testSoraAPIKeyAccountConnection(c, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := c.Request.Context()
|
|
||||||
recorder := &soraProbeRecorder{}
|
|
||||||
|
|
||||||
authToken := account.GetCredential("access_token")
|
|
||||||
if authToken == "" {
|
|
||||||
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, "No access token available")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set SSE headers
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
||||||
c.Writer.Header().Set("Connection", "keep-alive")
|
|
||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
c.Writer.Flush()
|
|
||||||
|
|
||||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
|
||||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
|
||||||
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send test_start event
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 Sora 客户端标准请求头
|
|
||||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
|
||||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
|
||||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
|
||||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
|
||||||
|
|
||||||
// Get proxy URL
|
|
||||||
proxyURL := ""
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
soraTLSProfile := s.resolveSoraTLSProfile()
|
|
||||||
|
|
||||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
|
|
||||||
if err != nil {
|
|
||||||
recorder.addStep("me", "failed", 0, "network_error", err.Error())
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
|
||||||
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
|
|
||||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
|
|
||||||
}
|
|
||||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
|
||||||
switch {
|
|
||||||
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
|
|
||||||
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
|
|
||||||
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
|
|
||||||
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
|
|
||||||
case strings.TrimSpace(upstreamMessage) != "":
|
|
||||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
|
|
||||||
default:
|
|
||||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
|
|
||||||
|
|
||||||
// 解析 /me 响应,提取用户信息
|
|
||||||
var meResp map[string]any
|
|
||||||
if err := json.Unmarshal(body, &meResp); err != nil {
|
|
||||||
// 能收到 200 就说明 token 有效
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
|
|
||||||
} else {
|
|
||||||
// 尝试提取用户名或邮箱信息
|
|
||||||
info := "Sora connection OK"
|
|
||||||
if name, ok := meResp["name"].(string); ok && name != "" {
|
|
||||||
info = fmt.Sprintf("Sora connection OK - User: %s", name)
|
|
||||||
} else if email, ok := meResp["email"].(string); ok && email != "" {
|
|
||||||
info = fmt.Sprintf("Sora connection OK - Email: %s", email)
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
|
|
||||||
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
|
|
||||||
if err == nil {
|
|
||||||
subReq.Header.Set("Authorization", "Bearer "+authToken)
|
|
||||||
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
|
||||||
subReq.Header.Set("Accept", "application/json")
|
|
||||||
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
|
||||||
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
|
|
||||||
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
|
|
||||||
|
|
||||||
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
|
|
||||||
if subErr != nil {
|
|
||||||
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
|
||||||
} else {
|
|
||||||
subBody, _ := io.ReadAll(subResp.Body)
|
|
||||||
_ = subResp.Body.Close()
|
|
||||||
if subResp.StatusCode == http.StatusOK {
|
|
||||||
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
|
|
||||||
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
|
||||||
} else {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
|
|
||||||
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
|
||||||
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
|
|
||||||
} else {
|
|
||||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
|
|
||||||
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
|
|
||||||
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
|
|
||||||
|
|
||||||
s.emitSoraProbeSummary(c, recorder)
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) testSora2Capabilities(
|
|
||||||
c *gin.Context,
|
|
||||||
ctx context.Context,
|
|
||||||
account *Account,
|
|
||||||
authToken string,
|
|
||||||
proxyURL string,
|
|
||||||
tlsProfile *tlsfingerprint.Profile,
|
|
||||||
recorder *soraProbeRecorder,
|
|
||||||
) {
|
|
||||||
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
|
|
||||||
ctx,
|
|
||||||
account,
|
|
||||||
authToken,
|
|
||||||
soraInviteMineURL,
|
|
||||||
proxyURL,
|
|
||||||
tlsProfile,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if inviteStatus == http.StatusUnauthorized {
|
|
||||||
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
|
|
||||||
ctx,
|
|
||||||
account,
|
|
||||||
authToken,
|
|
||||||
soraBootstrapURL,
|
|
||||||
proxyURL,
|
|
||||||
tlsProfile,
|
|
||||||
)
|
|
||||||
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
|
|
||||||
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
|
|
||||||
ctx,
|
|
||||||
account,
|
|
||||||
authToken,
|
|
||||||
soraInviteMineURL,
|
|
||||||
proxyURL,
|
|
||||||
tlsProfile,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if recorder != nil {
|
|
||||||
code := ""
|
|
||||||
msg := ""
|
|
||||||
if bootstrapErr != nil {
|
|
||||||
code = "network_error"
|
|
||||||
msg = bootstrapErr.Error()
|
|
||||||
}
|
|
||||||
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if inviteStatus != http.StatusOK {
|
|
||||||
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
|
|
||||||
}
|
|
||||||
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
|
|
||||||
}
|
|
||||||
|
|
||||||
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
|
||||||
} else {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
|
|
||||||
}
|
|
||||||
|
|
||||||
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
|
|
||||||
ctx,
|
|
||||||
account,
|
|
||||||
authToken,
|
|
||||||
soraRemainingURL,
|
|
||||||
proxyURL,
|
|
||||||
tlsProfile,
|
|
||||||
)
|
|
||||||
if remainingErr != nil {
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if remainingStatus != http.StatusOK {
|
|
||||||
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
|
|
||||||
}
|
|
||||||
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
|
|
||||||
}
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if recorder != nil {
|
|
||||||
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
|
|
||||||
}
|
|
||||||
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
|
||||||
} else {
|
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) fetchSoraTestEndpoint(
|
|
||||||
ctx context.Context,
|
|
||||||
account *Account,
|
|
||||||
authToken string,
|
|
||||||
url string,
|
|
||||||
proxyURL string,
|
|
||||||
tlsProfile *tlsfingerprint.Profile,
|
|
||||||
) (int, http.Header, []byte, error) {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
|
||||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
|
||||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
|
||||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
|
||||||
|
|
||||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, nil, err
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
body, readErr := io.ReadAll(resp.Body)
|
|
||||||
if readErr != nil {
|
|
||||||
return resp.StatusCode, resp.Header, nil, readErr
|
|
||||||
}
|
|
||||||
return resp.StatusCode, resp.Header, body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSoraSubscriptionSummary(body []byte) string {
|
|
||||||
var subResp struct {
|
|
||||||
Data []struct {
|
|
||||||
Plan struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
} `json:"plan"`
|
|
||||||
EndTS string `json:"end_ts"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &subResp); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if len(subResp.Data) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
first := subResp.Data[0]
|
|
||||||
parts := make([]string, 0, 3)
|
|
||||||
if first.Plan.Title != "" {
|
|
||||||
parts = append(parts, first.Plan.Title)
|
|
||||||
}
|
|
||||||
if first.Plan.ID != "" {
|
|
||||||
parts = append(parts, first.Plan.ID)
|
|
||||||
}
|
|
||||||
if first.EndTS != "" {
|
|
||||||
parts = append(parts, "end="+first.EndTS)
|
|
||||||
}
|
|
||||||
if len(parts) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return "Subscription: " + strings.Join(parts, " | ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSoraInviteSummary(body []byte) string {
|
|
||||||
var inviteResp struct {
|
|
||||||
InviteCode string `json:"invite_code"`
|
|
||||||
RedeemedCount int64 `json:"redeemed_count"`
|
|
||||||
TotalCount int64 `json:"total_count"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &inviteResp); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := []string{"Sora2: supported"}
|
|
||||||
if inviteResp.InviteCode != "" {
|
|
||||||
parts = append(parts, "invite="+inviteResp.InviteCode)
|
|
||||||
}
|
|
||||||
if inviteResp.TotalCount > 0 {
|
|
||||||
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
|
|
||||||
}
|
|
||||||
return strings.Join(parts, " | ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSoraRemainingSummary(body []byte) string {
|
|
||||||
var remainingResp struct {
|
|
||||||
RateLimitAndCreditBalance struct {
|
|
||||||
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
|
|
||||||
RateLimitReached bool `json:"rate_limit_reached"`
|
|
||||||
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
|
|
||||||
} `json:"rate_limit_and_credit_balance"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &remainingResp); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
info := remainingResp.RateLimitAndCreditBalance
|
|
||||||
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
|
|
||||||
if info.RateLimitReached {
|
|
||||||
parts = append(parts, "rate_limited=true")
|
|
||||||
}
|
|
||||||
if info.AccessResetsInSeconds > 0 {
|
|
||||||
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
|
|
||||||
}
|
|
||||||
return strings.Join(parts, " | ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
|
|
||||||
if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
|
|
||||||
// Sora TLS fingerprint enabled — use built-in default profile
|
|
||||||
return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
|
|
||||||
}
|
|
||||||
return nil // disabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
|
||||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
|
||||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
|
||||||
return soraerror.ExtractCloudflareRayID(headers, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractSoraEgressIPHint(headers http.Header) string {
|
|
||||||
if headers == nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
candidates := []string{
|
|
||||||
"x-openai-public-ip",
|
|
||||||
"x-envoy-external-address",
|
|
||||||
"cf-connecting-ip",
|
|
||||||
"x-forwarded-for",
|
|
||||||
}
|
|
||||||
for _, key := range candidates {
|
|
||||||
if value := strings.TrimSpace(headers.Get(key)); value != "" {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeProxyURLForLog(raw string) string {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
u, err := url.Parse(raw)
|
|
||||||
if err != nil {
|
|
||||||
return "<invalid_proxy_url>"
|
|
||||||
}
|
|
||||||
if u.User != nil {
|
|
||||||
u.User = nil
|
|
||||||
}
|
|
||||||
return u.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func endpointPathForLog(endpoint string) string {
|
|
||||||
parsed, err := url.Parse(strings.TrimSpace(endpoint))
|
|
||||||
if err != nil || parsed.Path == "" {
|
|
||||||
return endpoint
|
|
||||||
}
|
|
||||||
return parsed.Path
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
|
|
||||||
accountID := int64(0)
|
|
||||||
platform := ""
|
|
||||||
proxyID := "none"
|
|
||||||
if account != nil {
|
|
||||||
accountID = account.ID
|
|
||||||
platform = account.Platform
|
|
||||||
if account.ProxyID != nil {
|
|
||||||
proxyID = fmt.Sprintf("%d", *account.ProxyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cfRay := extractCloudflareRayID(headers, body)
|
|
||||||
if cfRay == "" {
|
|
||||||
cfRay = "unknown"
|
|
||||||
}
|
|
||||||
log.Printf(
|
|
||||||
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
|
|
||||||
accountID,
|
|
||||||
platform,
|
|
||||||
endpoint,
|
|
||||||
endpointPathForLog(endpoint),
|
|
||||||
proxyID,
|
|
||||||
sanitizeProxyURLForLog(proxyURL),
|
|
||||||
cfRay,
|
|
||||||
extractSoraEgressIPHint(headers),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func truncateSoraErrorBody(body []byte, max int) string {
|
|
||||||
return soraerror.TruncateBody(body, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
|
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||||
|
|||||||
@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
ctx, recorder := newSoraTestContext()
|
ctx, recorder := newTestContext()
|
||||||
svc := &AccountTestService{}
|
svc := &AccountTestService{}
|
||||||
|
|
||||||
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||||
|
|||||||
@ -4,16 +4,61 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// --- shared test helpers ---
|
||||||
|
|
||||||
|
type queuedHTTPUpstream struct {
|
||||||
|
responses []*http.Response
|
||||||
|
requests []*http.Request
|
||||||
|
tlsFlags []bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
return nil, fmt.Errorf("unexpected Do call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||||
|
u.requests = append(u.requests, req)
|
||||||
|
u.tlsFlags = append(u.tlsFlags, profile != nil)
|
||||||
|
if len(u.responses) == 0 {
|
||||||
|
return nil, fmt.Errorf("no mocked response")
|
||||||
|
}
|
||||||
|
resp := u.responses[0]
|
||||||
|
u.responses = u.responses[1:]
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJSONResponse(status int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- test functions ---
|
||||||
|
|
||||||
|
func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||||
|
return c, rec
|
||||||
|
}
|
||||||
|
|
||||||
type openAIAccountTestRepo struct {
|
type openAIAccountTestRepo struct {
|
||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
updatedExtra map[string]any
|
updatedExtra map[string]any
|
||||||
@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
|
|||||||
|
|
||||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
ctx, recorder := newSoraTestContext()
|
ctx, recorder := newTestContext()
|
||||||
|
|
||||||
resp := newJSONResponse(http.StatusOK, "")
|
resp := newJSONResponse(http.StatusOK, "")
|
||||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||||
@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
|
|||||||
|
|
||||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
ctx, _ := newSoraTestContext()
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||||
|
|||||||
@ -1,320 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type queuedHTTPUpstream struct {
|
|
||||||
responses []*http.Response
|
|
||||||
requests []*http.Request
|
|
||||||
tlsFlags []bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
|
||||||
return nil, fmt.Errorf("unexpected Do call")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
|
||||||
u.requests = append(u.requests, req)
|
|
||||||
u.tlsFlags = append(u.tlsFlags, profile != nil)
|
|
||||||
if len(u.responses) == 0 {
|
|
||||||
return nil, fmt.Errorf("no mocked response")
|
|
||||||
}
|
|
||||||
resp := u.responses[0]
|
|
||||||
u.responses = u.responses[1:]
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newJSONResponse(status int, body string) *http.Response {
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: status,
|
|
||||||
Header: make(http.Header),
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
|
||||||
resp := newJSONResponse(status, body)
|
|
||||||
resp.Header.Set(key, value)
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(rec)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
|
||||||
return c, rec
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
|
||||||
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
|
|
||||||
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
|
|
||||||
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{
|
|
||||||
httpUpstream: upstream,
|
|
||||||
cfg: &config.Config{
|
|
||||||
Gateway: config.GatewayConfig{
|
|
||||||
TLSFingerprint: config.TLSFingerprintConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
DisableTLSFingerprint: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c, rec := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c, account)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, upstream.requests, 4)
|
|
||||||
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
|
|
||||||
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
|
|
||||||
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
|
|
||||||
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
|
|
||||||
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
|
|
||||||
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
|
|
||||||
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
|
|
||||||
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, `"type":"test_start"`)
|
|
||||||
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
|
|
||||||
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
|
||||||
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
|
|
||||||
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
|
|
||||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
|
||||||
require.Contains(t, body, `"status":"success"`)
|
|
||||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
|
||||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
|
||||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
|
|
||||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{httpUpstream: upstream}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c, rec := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c, account)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, upstream.requests, 4)
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
|
||||||
require.Contains(t, body, "Subscription check returned 403")
|
|
||||||
require.Contains(t, body, "Sora2 invite check returned 401")
|
|
||||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
|
||||||
require.Contains(t, body, `"status":"partial_success"`)
|
|
||||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{httpUpstream: upstream}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c, rec := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c, account)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
|
||||||
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, `"type":"error"`)
|
|
||||||
require.Contains(t, body, "Cloudflare challenge")
|
|
||||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{httpUpstream: upstream}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c, rec := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c, account)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
|
||||||
require.Contains(t, err.Error(), "HTTP 429")
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, "Cloudflare challenge")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{httpUpstream: upstream}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c, rec := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c, account)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "token_invalidated")
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
|
||||||
require.Contains(t, body, `"status":"failed"`)
|
|
||||||
require.Contains(t, body, "token_invalidated")
|
|
||||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{
|
|
||||||
httpUpstream: upstream,
|
|
||||||
soraTestCooldown: time.Hour,
|
|
||||||
}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c1, _ := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c1, account)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c2, rec2 := newSoraTestContext()
|
|
||||||
err = svc.testSoraAccountConnection(c2, account)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "测试过于频繁")
|
|
||||||
body := rec2.Body.String()
|
|
||||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
|
||||||
require.Contains(t, body, `"code":"test_rate_limited"`)
|
|
||||||
require.Contains(t, body, `"status":"failed"`)
|
|
||||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
|
||||||
upstream := &queuedHTTPUpstream{
|
|
||||||
responses: []*http.Response{
|
|
||||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
|
||||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
|
||||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &AccountTestService{httpUpstream: upstream}
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "test_token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
c, rec := newSoraTestContext()
|
|
||||||
err := svc.testSoraAccountConnection(c, account)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
|
|
||||||
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
|
|
||||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
|
||||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeProxyURLForLog(t *testing.T) {
|
|
||||||
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
|
|
||||||
require.Equal(t, "", sanitizeProxyURLForLog(""))
|
|
||||||
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractSoraEgressIPHint(t *testing.T) {
|
|
||||||
h := make(http.Header)
|
|
||||||
h.Set("x-openai-public-ip", "203.0.113.10")
|
|
||||||
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
|
|
||||||
|
|
||||||
h2 := make(http.Header)
|
|
||||||
h2.Set("x-envoy-external-address", "198.51.100.9")
|
|
||||||
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
|
|
||||||
|
|
||||||
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
|
|
||||||
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
|
|
||||||
}
|
|
||||||
@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
"github.com/Wei-Shaw/sub2api/internal/util/httputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminService interface defines admin management operations
|
// AdminService interface defines admin management operations
|
||||||
@ -104,14 +104,13 @@ type AdminService interface {
|
|||||||
|
|
||||||
// CreateUserInput represents input for creating a new user via admin operations.
|
// CreateUserInput represents input for creating a new user via admin operations.
|
||||||
type CreateUserInput struct {
|
type CreateUserInput struct {
|
||||||
Email string
|
Email string
|
||||||
Password string
|
Password string
|
||||||
Username string
|
Username string
|
||||||
Notes string
|
Notes string
|
||||||
Balance float64
|
Balance float64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
AllowedGroups []int64
|
AllowedGroups []int64
|
||||||
SoraStorageQuotaBytes int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateUserInput struct {
|
type UpdateUserInput struct {
|
||||||
@ -125,8 +124,7 @@ type UpdateUserInput struct {
|
|||||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||||
GroupRates map[int64]*float64
|
GroupRates map[int64]*float64
|
||||||
SoraStorageQuotaBytes *int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateGroupInput struct {
|
type CreateGroupInput struct {
|
||||||
@ -140,16 +138,11 @@ type CreateGroupInput struct {
|
|||||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||||
ImagePrice1K *float64
|
ImagePrice1K *float64
|
||||||
ImagePrice2K *float64
|
ImagePrice2K *float64
|
||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
// Sora 按次计费配置
|
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||||
SoraImagePrice360 *float64
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
SoraImagePrice540 *float64
|
|
||||||
SoraVideoPricePerRequest *float64
|
|
||||||
SoraVideoPricePerRequestHD *float64
|
|
||||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
|
||||||
FallbackGroupID *int64 // 降级分组 ID
|
|
||||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
FallbackGroupIDOnInvalidRequest *int64
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
@ -158,8 +151,6 @@ type CreateGroupInput struct {
|
|||||||
MCPXMLInject *bool
|
MCPXMLInject *bool
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string
|
SupportedModelScopes []string
|
||||||
// Sora 存储配额
|
|
||||||
SoraStorageQuotaBytes int64
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
AllowMessagesDispatch bool
|
AllowMessagesDispatch bool
|
||||||
DefaultMappedModel string
|
DefaultMappedModel string
|
||||||
@ -181,16 +172,11 @@ type UpdateGroupInput struct {
|
|||||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||||
ImagePrice1K *float64
|
ImagePrice1K *float64
|
||||||
ImagePrice2K *float64
|
ImagePrice2K *float64
|
||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
// Sora 按次计费配置
|
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||||
SoraImagePrice360 *float64
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
SoraImagePrice540 *float64
|
|
||||||
SoraVideoPricePerRequest *float64
|
|
||||||
SoraVideoPricePerRequestHD *float64
|
|
||||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
|
||||||
FallbackGroupID *int64 // 降级分组 ID
|
|
||||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
FallbackGroupIDOnInvalidRequest *int64
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
@ -199,8 +185,6 @@ type UpdateGroupInput struct {
|
|||||||
MCPXMLInject *bool
|
MCPXMLInject *bool
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string
|
SupportedModelScopes *[]string
|
||||||
// Sora 存储配额
|
|
||||||
SoraStorageQuotaBytes *int64
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
AllowMessagesDispatch *bool
|
AllowMessagesDispatch *bool
|
||||||
DefaultMappedModel *string
|
DefaultMappedModel *string
|
||||||
@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{
|
|||||||
http.StatusOK: {},
|
http.StatusOK: {},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Target: "sora",
|
|
||||||
URL: "https://sora.chatgpt.com/backend/me",
|
|
||||||
Method: http.MethodGet,
|
|
||||||
AllowedStatuses: map[int]struct{}{
|
|
||||||
http.StatusUnauthorized: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -448,7 +424,6 @@ type adminServiceImpl struct {
|
|||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
|
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
apiKeyRepo APIKeyRepository
|
apiKeyRepo APIKeyRepository
|
||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
@ -473,7 +448,6 @@ func NewAdminService(
|
|||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
groupRepo GroupRepository,
|
groupRepo GroupRepository,
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
soraAccountRepo SoraAccountRepository,
|
|
||||||
proxyRepo ProxyRepository,
|
proxyRepo ProxyRepository,
|
||||||
apiKeyRepo APIKeyRepository,
|
apiKeyRepo APIKeyRepository,
|
||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
@ -492,7 +466,6 @@ func NewAdminService(
|
|||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
soraAccountRepo: soraAccountRepo,
|
|
||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
@ -574,15 +547,14 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
|
|||||||
|
|
||||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||||
user := &User{
|
user := &User{
|
||||||
Email: input.Email,
|
Email: input.Email,
|
||||||
Username: input.Username,
|
Username: input.Username,
|
||||||
Notes: input.Notes,
|
Notes: input.Notes,
|
||||||
Role: RoleUser, // Always create as regular user, never admin
|
Role: RoleUser, // Always create as regular user, never admin
|
||||||
Balance: input.Balance,
|
Balance: input.Balance,
|
||||||
Concurrency: input.Concurrency,
|
Concurrency: input.Concurrency,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
AllowedGroups: input.AllowedGroups,
|
AllowedGroups: input.AllowedGroups,
|
||||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
|
||||||
}
|
}
|
||||||
if err := user.SetPassword(input.Password); err != nil {
|
if err := user.SetPassword(input.Password); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
user.AllowedGroups = *input.AllowedGroups
|
user.AllowedGroups = *input.AllowedGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.SoraStorageQuotaBytes != nil {
|
|
||||||
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||||
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
|
|
||||||
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
|
|
||||||
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
|
|
||||||
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
|
|
||||||
|
|
||||||
// 校验降级分组
|
// 校验降级分组
|
||||||
if input.FallbackGroupID != nil {
|
if input.FallbackGroupID != nil {
|
||||||
@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
ImagePrice1K: imagePrice1K,
|
ImagePrice1K: imagePrice1K,
|
||||||
ImagePrice2K: imagePrice2K,
|
ImagePrice2K: imagePrice2K,
|
||||||
ImagePrice4K: imagePrice4K,
|
ImagePrice4K: imagePrice4K,
|
||||||
SoraImagePrice360: soraImagePrice360,
|
|
||||||
SoraImagePrice540: soraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: soraVideoPrice,
|
|
||||||
SoraVideoPricePerRequestHD: soraVideoPriceHD,
|
|
||||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||||
FallbackGroupID: input.FallbackGroupID,
|
FallbackGroupID: input.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||||
ModelRouting: input.ModelRouting,
|
ModelRouting: input.ModelRouting,
|
||||||
MCPXMLInject: mcpXMLInject,
|
MCPXMLInject: mcpXMLInject,
|
||||||
SupportedModelScopes: input.SupportedModelScopes,
|
SupportedModelScopes: input.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
|
||||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||||
RequireOAuthOnly: input.RequireOAuthOnly,
|
RequireOAuthOnly: input.RequireOAuthOnly,
|
||||||
RequirePrivacySet: input.RequirePrivacySet,
|
RequirePrivacySet: input.RequirePrivacySet,
|
||||||
@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.ImagePrice4K != nil {
|
if input.ImagePrice4K != nil {
|
||||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||||
}
|
}
|
||||||
if input.SoraImagePrice360 != nil {
|
|
||||||
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
|
|
||||||
}
|
|
||||||
if input.SoraImagePrice540 != nil {
|
|
||||||
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
|
|
||||||
}
|
|
||||||
if input.SoraVideoPricePerRequest != nil {
|
|
||||||
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
|
|
||||||
}
|
|
||||||
if input.SoraVideoPricePerRequestHD != nil {
|
|
||||||
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
|
|
||||||
}
|
|
||||||
if input.SoraStorageQuotaBytes != nil {
|
|
||||||
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
if input.ClaudeCodeOnly != nil {
|
if input.ClaudeCodeOnly != nil {
|
||||||
@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sora apikey 账号的 base_url 必填校验
|
|
||||||
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
|
|
||||||
baseURL, _ := input.Credentials["base_url"].(string)
|
|
||||||
baseURL = strings.TrimSpace(baseURL)
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
|
||||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Notes: normalizeAccountNotes(input.Notes),
|
Notes: normalizeAccountNotes(input.Notes),
|
||||||
@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
|
|
||||||
if account.Platform == PlatformSora && s.soraAccountRepo != nil {
|
|
||||||
soraUpdates := map[string]any{
|
|
||||||
"access_token": account.GetCredential("access_token"),
|
|
||||||
"refresh_token": account.GetCredential("refresh_token"),
|
|
||||||
}
|
|
||||||
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
|
|
||||||
// 只记录警告日志,不阻塞账号创建
|
|
||||||
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 绑定分组
|
// 绑定分组
|
||||||
if len(groupIDs) > 0 {
|
if len(groupIDs) > 0 {
|
||||||
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
||||||
@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sora apikey 账号的 base_url 必填校验
|
|
||||||
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
|
|
||||||
baseURL, _ := account.Credentials["base_url"].(string)
|
|
||||||
baseURL = strings.TrimSpace(baseURL)
|
|
||||||
if baseURL == "" {
|
|
||||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
|
||||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先验证分组是否存在(在任何写操作之前)
|
// 先验证分组是否存在(在任何写操作之前)
|
||||||
if input.GroupIDs != nil {
|
if input.GroupIDs != nil {
|
||||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||||
@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
|
|||||||
body = body[:proxyQualityMaxBodyBytes]
|
body = body[:proxyQualityMaxBodyBytes]
|
||||||
}
|
}
|
||||||
|
|
||||||
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
// Cloudflare challenge 检测
|
||||||
|
if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||||
item.Status = "challenge"
|
item.Status = "challenge"
|
||||||
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
|
item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
|
||||||
item.Message = "Sora 命中 Cloudflare challenge"
|
item.Message = "命中 Cloudflare challenge"
|
||||||
return item
|
return item
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
|
|||||||
require.Contains(t, result.Summary, "挑战 1 项")
|
require.Contains(t, result.Summary, "挑战 1 项")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/html")
|
w.Header().Set("Content-Type", "text/html")
|
||||||
w.Header().Set("cf-ray", "test-ray-123")
|
w.Header().Set("cf-ray", "test-ray-123")
|
||||||
@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
target := proxyQualityTarget{
|
target := proxyQualityTarget{
|
||||||
Target: "sora",
|
Target: "openai",
|
||||||
URL: server.URL,
|
URL: server.URL,
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
AllowedStatuses: map[int]struct{}{
|
AllowedStatuses: map[int]struct{}{
|
||||||
|
|||||||
@ -5,13 +5,12 @@ package service
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
|
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
|
||||||
@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
|
|||||||
m.responseBodies[respIdx] = bodyBytes
|
m.responseBodies[respIdx] = bodyBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
// 用缓存的 body 字节重建新的 reader
|
// 用缓存的 body 重建 reader(支持重试场景多次读取)
|
||||||
var body io.ReadCloser
|
cloned := *resp
|
||||||
if m.responseBodies[respIdx] != nil {
|
if m.responseBodies[respIdx] != nil {
|
||||||
body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
|
cloned.Body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
|
||||||
}
|
}
|
||||||
|
return &cloned, respErr
|
||||||
return &http.Response{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
Header: resp.Header.Clone(),
|
|
||||||
Body: body,
|
|
||||||
}, respErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||||
|
|||||||
@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct {
|
|||||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
|
|||||||
@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
|
||||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
|
||||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
|
||||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||||
@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
|
||||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
|
||||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
|
||||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
|
||||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||||
|
|||||||
@ -808,14 +808,6 @@ type ImagePriceConfig struct {
|
|||||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraPriceConfig Sora 按次计费配置
|
|
||||||
type SoraPriceConfig struct {
|
|
||||||
ImagePrice360 *float64
|
|
||||||
ImagePrice540 *float64
|
|
||||||
VideoPricePerRequest *float64
|
|
||||||
VideoPricePerRequestHD *float64
|
|
||||||
}
|
|
||||||
|
|
||||||
// CalculateImageCost 计算图片生成费用
|
// CalculateImageCost 计算图片生成费用
|
||||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||||
@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateSoraImageCost 计算 Sora 图片按次费用
|
|
||||||
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
|
||||||
if imageCount <= 0 {
|
|
||||||
return &CostBreakdown{}
|
|
||||||
}
|
|
||||||
|
|
||||||
unitPrice := 0.0
|
|
||||||
if groupConfig != nil {
|
|
||||||
switch imageSize {
|
|
||||||
case "540":
|
|
||||||
if groupConfig.ImagePrice540 != nil {
|
|
||||||
unitPrice = *groupConfig.ImagePrice540
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if groupConfig.ImagePrice360 != nil {
|
|
||||||
unitPrice = *groupConfig.ImagePrice360
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
totalCost := unitPrice * float64(imageCount)
|
|
||||||
if rateMultiplier <= 0 {
|
|
||||||
rateMultiplier = 1.0
|
|
||||||
}
|
|
||||||
actualCost := totalCost * rateMultiplier
|
|
||||||
|
|
||||||
return &CostBreakdown{
|
|
||||||
TotalCost: totalCost,
|
|
||||||
ActualCost: actualCost,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CalculateSoraVideoCost 计算 Sora 视频按次费用
|
|
||||||
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
|
||||||
unitPrice := 0.0
|
|
||||||
if groupConfig != nil {
|
|
||||||
modelLower := strings.ToLower(model)
|
|
||||||
if strings.Contains(modelLower, "sora2pro-hd") {
|
|
||||||
if groupConfig.VideoPricePerRequestHD != nil {
|
|
||||||
unitPrice = *groupConfig.VideoPricePerRequestHD
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
|
|
||||||
unitPrice = *groupConfig.VideoPricePerRequest
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
totalCost := unitPrice
|
|
||||||
if rateMultiplier <= 0 {
|
|
||||||
rateMultiplier = 1.0
|
|
||||||
}
|
|
||||||
actualCost := totalCost * rateMultiplier
|
|
||||||
|
|
||||||
return &CostBreakdown{
|
|
||||||
TotalCost: totalCost,
|
|
||||||
ActualCost: actualCost,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getImageUnitPrice 获取图片单价
|
// getImageUnitPrice 获取图片单价
|
||||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||||
// 优先使用分组配置的价格
|
// 优先使用分组配置的价格
|
||||||
|
|||||||
@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
|
|||||||
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
|
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCalculateSoraVideoCost(t *testing.T) {
|
|
||||||
svc := newTestBillingService()
|
|
||||||
|
|
||||||
price := 0.5
|
|
||||||
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
|
|
||||||
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
|
|
||||||
|
|
||||||
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
|
|
||||||
svc := newTestBillingService()
|
|
||||||
|
|
||||||
hdPrice := 1.0
|
|
||||||
normalPrice := 0.5
|
|
||||||
cfg := &SoraPriceConfig{
|
|
||||||
VideoPricePerRequest: &normalPrice,
|
|
||||||
VideoPricePerRequestHD: &hdPrice,
|
|
||||||
}
|
|
||||||
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
|
|
||||||
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsModelSupported(t *testing.T) {
|
func TestIsModelSupported(t *testing.T) {
|
||||||
svc := newTestBillingService()
|
svc := newTestBillingService()
|
||||||
@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "not initialized")
|
require.Contains(t, err.Error(), "not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCalculateSoraImageCost(t *testing.T) {
|
|
||||||
svc := newTestBillingService()
|
|
||||||
|
|
||||||
price360 := 0.05
|
|
||||||
price540 := 0.08
|
|
||||||
cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
|
|
||||||
|
|
||||||
cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
|
|
||||||
require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
|
|
||||||
|
|
||||||
cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
|
|
||||||
require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
|
|
||||||
require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
|
|
||||||
svc := newTestBillingService()
|
|
||||||
cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
|
|
||||||
require.Equal(t, 0.0, cost.TotalCost)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
|
|
||||||
svc := newTestBillingService()
|
|
||||||
cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
|
|
||||||
require.Equal(t, 0.0, cost.TotalCost)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
|
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
|
||||||
// 使用空的 fallback prices 让 GetModelPricing 失败
|
// 使用空的 fallback prices 让 GetModelPricing 失败
|
||||||
svc := &BillingService{
|
svc := &BillingService{
|
||||||
|
|||||||
@ -24,7 +24,6 @@ const (
|
|||||||
PlatformOpenAI = domain.PlatformOpenAI
|
PlatformOpenAI = domain.PlatformOpenAI
|
||||||
PlatformGemini = domain.PlatformGemini
|
PlatformGemini = domain.PlatformGemini
|
||||||
PlatformAntigravity = domain.PlatformAntigravity
|
PlatformAntigravity = domain.PlatformAntigravity
|
||||||
PlatformSora = domain.PlatformSora
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
@ -107,7 +106,6 @@ const (
|
|||||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
|
|
||||||
SettingKeySiteName = "site_name" // 网站名称
|
SettingKeySiteName = "site_name" // 网站名称
|
||||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||||
@ -199,27 +197,6 @@ const (
|
|||||||
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||||
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||||
|
|
||||||
// =========================
|
|
||||||
// Sora S3 存储配置
|
|
||||||
// =========================
|
|
||||||
|
|
||||||
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
|
|
||||||
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
|
|
||||||
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
|
|
||||||
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
|
|
||||||
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
|
|
||||||
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
|
|
||||||
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
|
|
||||||
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
|
|
||||||
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
|
|
||||||
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
|
|
||||||
|
|
||||||
// =========================
|
|
||||||
// Sora 用户存储配额
|
|
||||||
// =========================
|
|
||||||
|
|
||||||
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
|
|
||||||
|
|
||||||
// =========================
|
// =========================
|
||||||
// Claude Code Version Check
|
// Claude Code Version Check
|
||||||
// =========================
|
// =========================
|
||||||
|
|||||||
@ -60,13 +60,6 @@ const (
|
|||||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MediaType 媒体类型常量
|
|
||||||
const (
|
|
||||||
MediaTypeImage = "image"
|
|
||||||
MediaTypeVideo = "video"
|
|
||||||
MediaTypePrompt = "prompt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||||
type forceCacheBillingKeyType struct{}
|
type forceCacheBillingKeyType struct{}
|
||||||
@ -511,9 +504,6 @@ type ForwardResult struct {
|
|||||||
ImageCount int // 生成的图片数量
|
ImageCount int // 生成的图片数量
|
||||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||||
|
|
||||||
// Sora 媒体字段
|
|
||||||
MediaType string // image / video / prompt
|
|
||||||
MediaURL string // 生成后的媒体地址(可选)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||||
@ -1971,9 +1961,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||||
if platform == PlatformSora {
|
|
||||||
return s.listSoraSchedulableAccounts(ctx, groupID)
|
|
||||||
}
|
|
||||||
if s.schedulerSnapshot != nil {
|
if s.schedulerSnapshot != nil {
|
||||||
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -2070,53 +2057,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
|||||||
return accounts, useMixed, nil
|
return accounts, useMixed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
|
|
||||||
const useMixed = false
|
|
||||||
|
|
||||||
var accounts []Account
|
|
||||||
var err error
|
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
|
||||||
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
|
|
||||||
} else if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug("account_scheduling_list_failed",
|
|
||||||
"group_id", derefGroupID(groupID),
|
|
||||||
"platform", PlatformSora,
|
|
||||||
"error", err)
|
|
||||||
return nil, useMixed, err
|
|
||||||
}
|
|
||||||
|
|
||||||
filtered := make([]Account, 0, len(accounts))
|
|
||||||
for _, acc := range accounts {
|
|
||||||
if acc.Platform != PlatformSora {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !s.isSoraAccountSchedulable(&acc) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
filtered = append(filtered, acc)
|
|
||||||
}
|
|
||||||
slog.Debug("account_scheduling_list_sora",
|
|
||||||
"group_id", derefGroupID(groupID),
|
|
||||||
"platform", PlatformSora,
|
|
||||||
"raw_count", len(accounts),
|
|
||||||
"filtered_count", len(filtered))
|
|
||||||
for _, acc := range filtered {
|
|
||||||
slog.Debug("account_scheduling_account_detail",
|
|
||||||
"account_id", acc.ID,
|
|
||||||
"name", acc.Name,
|
|
||||||
"platform", acc.Platform,
|
|
||||||
"type", acc.Type,
|
|
||||||
"status", acc.Status,
|
|
||||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
|
||||||
}
|
|
||||||
return filtered, useMixed, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
|
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
|
||||||
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
|
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
|
||||||
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
|
||||||
@ -2141,33 +2081,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
|
|||||||
return account.Platform == platform
|
return account.Platform == platform
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
|
|
||||||
return s.soraUnschedulableReason(account) == ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayService) soraUnschedulableReason(account *Account) string {
|
|
||||||
if account == nil {
|
|
||||||
return "account_nil"
|
|
||||||
}
|
|
||||||
if account.Status != StatusActive {
|
|
||||||
return fmt.Sprintf("status=%s", account.Status)
|
|
||||||
}
|
|
||||||
if !account.Schedulable {
|
|
||||||
return "schedulable=false"
|
|
||||||
}
|
|
||||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
|
||||||
return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
|
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if account.Platform == PlatformSora {
|
|
||||||
return s.isSoraAccountSchedulable(account)
|
|
||||||
}
|
|
||||||
return account.IsSchedulable()
|
return account.IsSchedulable()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2175,12 +2092,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
|
|||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if account.Platform == PlatformSora {
|
|
||||||
if !s.isSoraAccountSchedulable(account) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
|
|
||||||
}
|
|
||||||
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
|
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3357,9 +3268,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
|
|||||||
stats.SampleMappingIDs,
|
stats.SampleMappingIDs,
|
||||||
stats.SampleRateLimitIDs,
|
stats.SampleRateLimitIDs,
|
||||||
)
|
)
|
||||||
if platform == PlatformSora {
|
|
||||||
s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
|
|
||||||
}
|
|
||||||
return stats
|
return stats
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3416,11 +3324,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
return selectionFailureDiagnosis{Category: "excluded"}
|
return selectionFailureDiagnosis{Category: "excluded"}
|
||||||
}
|
}
|
||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
detail := "generic_unschedulable"
|
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||||
if acc.Platform == PlatformSora {
|
|
||||||
detail = s.soraUnschedulableReason(acc)
|
|
||||||
}
|
|
||||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
|
|
||||||
}
|
}
|
||||||
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
||||||
return selectionFailureDiagnosis{
|
return selectionFailureDiagnosis{
|
||||||
@ -3444,57 +3348,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
return selectionFailureDiagnosis{Category: "eligible"}
|
return selectionFailureDiagnosis{Category: "eligible"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) logSoraSelectionFailureDetails(
|
|
||||||
ctx context.Context,
|
|
||||||
groupID *int64,
|
|
||||||
sessionHash string,
|
|
||||||
requestedModel string,
|
|
||||||
accounts []Account,
|
|
||||||
excludedIDs map[int64]struct{},
|
|
||||||
allowMixedScheduling bool,
|
|
||||||
) {
|
|
||||||
const maxLines = 30
|
|
||||||
logged := 0
|
|
||||||
|
|
||||||
for i := range accounts {
|
|
||||||
if logged >= maxLines {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
acc := &accounts[i]
|
|
||||||
diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
|
|
||||||
if diagnosis.Category == "eligible" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
detail := diagnosis.Detail
|
|
||||||
if detail == "" {
|
|
||||||
detail = "-"
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"service.gateway",
|
|
||||||
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
|
|
||||||
derefGroupID(groupID),
|
|
||||||
requestedModel,
|
|
||||||
shortSessionHash(sessionHash),
|
|
||||||
acc.ID,
|
|
||||||
acc.Platform,
|
|
||||||
diagnosis.Category,
|
|
||||||
detail,
|
|
||||||
)
|
|
||||||
logged++
|
|
||||||
}
|
|
||||||
if len(accounts) > maxLines {
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"service.gateway",
|
|
||||||
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
|
|
||||||
derefGroupID(groupID),
|
|
||||||
requestedModel,
|
|
||||||
shortSessionHash(sessionHash),
|
|
||||||
len(accounts),
|
|
||||||
logged,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
||||||
if acc == nil {
|
if acc == nil {
|
||||||
return true
|
return true
|
||||||
@ -3573,9 +3426,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
|||||||
}
|
}
|
||||||
return mapAntigravityModel(account, requestedModel) != ""
|
return mapAntigravityModel(account, requestedModel) != ""
|
||||||
}
|
}
|
||||||
if account.Platform == PlatformSora {
|
|
||||||
return s.isSoraModelSupportedByAccount(account, requestedModel)
|
|
||||||
}
|
|
||||||
if account.IsBedrock() {
|
if account.IsBedrock() {
|
||||||
_, ok := ResolveBedrockModelID(account, requestedModel)
|
_, ok := ResolveBedrockModelID(account, requestedModel)
|
||||||
return ok
|
return ok
|
||||||
@ -3588,143 +3438,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
|||||||
return account.IsModelSupported(requestedModel)
|
return account.IsModelSupported(requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
|
|
||||||
if account == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(requestedModel) == "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先走原始精确/通配符匹配。
|
|
||||||
mapping := account.GetModelMapping()
|
|
||||||
if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
aliases := buildSoraModelAliases(requestedModel)
|
|
||||||
if len(aliases) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
hasSoraSelector := false
|
|
||||||
for pattern := range mapping {
|
|
||||||
if !isSoraModelSelector(pattern) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
hasSoraSelector = true
|
|
||||||
if matchPatternAnyAlias(pattern, aliases) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
|
|
||||||
// 此时不应误拦截 Sora 模型请求。
|
|
||||||
if !hasSoraSelector {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func matchPatternAnyAlias(pattern string, aliases []string) bool {
|
|
||||||
normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
|
|
||||||
if normalizedPattern == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, alias := range aliases {
|
|
||||||
if matchWildcard(normalizedPattern, alias) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSoraModelSelector(pattern string) bool {
|
|
||||||
p := strings.ToLower(strings.TrimSpace(pattern))
|
|
||||||
if p == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(p, "sora"),
|
|
||||||
strings.HasPrefix(p, "gpt-image"),
|
|
||||||
strings.HasPrefix(p, "prompt-enhance"),
|
|
||||||
strings.HasPrefix(p, "sy_"):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return p == "video" || p == "image"
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildSoraModelAliases(requestedModel string) []string {
|
|
||||||
modelID := strings.ToLower(strings.TrimSpace(requestedModel))
|
|
||||||
if modelID == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
aliases := make([]string, 0, 8)
|
|
||||||
addAlias := func(value string) {
|
|
||||||
v := strings.ToLower(strings.TrimSpace(value))
|
|
||||||
if v == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, existing := range aliases {
|
|
||||||
if existing == v {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
aliases = append(aliases, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
addAlias(modelID)
|
|
||||||
cfg, ok := GetSoraModelConfig(modelID)
|
|
||||||
if ok {
|
|
||||||
addAlias(cfg.Model)
|
|
||||||
switch cfg.Type {
|
|
||||||
case "video":
|
|
||||||
addAlias("video")
|
|
||||||
addAlias("sora")
|
|
||||||
addAlias(soraVideoFamilyAlias(modelID))
|
|
||||||
case "image":
|
|
||||||
addAlias("image")
|
|
||||||
addAlias("gpt-image")
|
|
||||||
case "prompt_enhance":
|
|
||||||
addAlias("prompt-enhance")
|
|
||||||
}
|
|
||||||
return aliases
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(modelID, "sora"):
|
|
||||||
addAlias("video")
|
|
||||||
addAlias("sora")
|
|
||||||
addAlias(soraVideoFamilyAlias(modelID))
|
|
||||||
case strings.HasPrefix(modelID, "gpt-image"):
|
|
||||||
addAlias("image")
|
|
||||||
addAlias("gpt-image")
|
|
||||||
case strings.HasPrefix(modelID, "prompt-enhance"):
|
|
||||||
addAlias("prompt-enhance")
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return aliases
|
|
||||||
}
|
|
||||||
|
|
||||||
func soraVideoFamilyAlias(modelID string) string {
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(modelID, "sora2pro-hd"):
|
|
||||||
return "sora2pro-hd"
|
|
||||||
case strings.HasPrefix(modelID, "sora2pro"):
|
|
||||||
return "sora2pro"
|
|
||||||
case strings.HasPrefix(modelID, "sora2"):
|
|
||||||
return "sora2"
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccessToken 获取账号凭证
|
// GetAccessToken 获取账号凭证
|
||||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||||
switch account.Type {
|
switch account.Type {
|
||||||
@ -7592,9 +7305,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
|||||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||||
cmd.ImageCount = usageLog.ImageCount
|
cmd.ImageCount = usageLog.ImageCount
|
||||||
if usageLog.MediaType != nil {
|
|
||||||
cmd.MediaType = *usageLog.MediaType
|
|
||||||
}
|
|
||||||
if usageLog.ServiceTier != nil {
|
if usageLog.ServiceTier != nil {
|
||||||
cmd.ServiceTier = *usageLog.ServiceTier
|
cmd.ServiceTier = *usageLog.ServiceTier
|
||||||
}
|
}
|
||||||
@ -7750,8 +7460,6 @@ type recordUsageOpts struct {
|
|||||||
|
|
||||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||||
// - Claude Max 缓存计费策略
|
// - Claude Max 缓存计费策略
|
||||||
// - Sora 媒体类型分支(image/video/prompt)
|
|
||||||
// - MediaType 字段写入使用日志
|
|
||||||
EnableClaudePath bool
|
EnableClaudePath bool
|
||||||
|
|
||||||
// 长上下文计费(仅 Gemini 路径需要)
|
// 长上下文计费(仅 Gemini 路径需要)
|
||||||
@ -7842,7 +7550,6 @@ type recordUsageCoreInput struct {
|
|||||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||||
// opts 中的字段控制两者之间的差异行为:
|
// opts 中的字段控制两者之间的差异行为:
|
||||||
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||||
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
|
|
||||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
@ -7944,16 +7651,6 @@ func (s *GatewayService) calculateRecordUsageCost(
|
|||||||
multiplier float64,
|
multiplier float64,
|
||||||
opts *recordUsageOpts,
|
opts *recordUsageOpts,
|
||||||
) *CostBreakdown {
|
) *CostBreakdown {
|
||||||
// Sora 媒体类型分支(仅 Claude 路径启用)
|
|
||||||
if opts.EnableClaudePath {
|
|
||||||
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
|
|
||||||
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
|
|
||||||
}
|
|
||||||
if result.MediaType == MediaTypePrompt {
|
|
||||||
return &CostBreakdown{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 图片生成计费
|
// 图片生成计费
|
||||||
if result.ImageCount > 0 {
|
if result.ImageCount > 0 {
|
||||||
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
||||||
@ -7963,28 +7660,6 @@ func (s *GatewayService) calculateRecordUsageCost(
|
|||||||
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
|
|
||||||
func (s *GatewayService) calculateSoraMediaCost(
|
|
||||||
result *ForwardResult,
|
|
||||||
apiKey *APIKey,
|
|
||||||
billingModel string,
|
|
||||||
multiplier float64,
|
|
||||||
) *CostBreakdown {
|
|
||||||
var soraConfig *SoraPriceConfig
|
|
||||||
if apiKey.Group != nil {
|
|
||||||
soraConfig = &SoraPriceConfig{
|
|
||||||
ImagePrice360: apiKey.Group.SoraImagePrice360,
|
|
||||||
ImagePrice540: apiKey.Group.SoraImagePrice540,
|
|
||||||
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
|
||||||
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if result.MediaType == MediaTypeImage {
|
|
||||||
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
|
||||||
}
|
|
||||||
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||||
@ -8133,13 +7808,12 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
AccountRateMultiplier: &accountRateMultiplier,
|
AccountRateMultiplier: &accountRateMultiplier,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
BillingMode: resolveBillingMode(opts, result, cost),
|
BillingMode: resolveBillingMode(result, cost),
|
||||||
Stream: result.Stream,
|
Stream: result.Stream,
|
||||||
DurationMs: &durationMs,
|
DurationMs: &durationMs,
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
ImageCount: result.ImageCount,
|
ImageCount: result.ImageCount,
|
||||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||||
MediaType: resolveMediaType(opts, result),
|
|
||||||
CacheTTLOverridden: cacheTTLOverridden,
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||||
@ -8163,13 +7837,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||||
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
|
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
|
||||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
|
||||||
isSoraMedia := opts.EnableClaudePath &&
|
|
||||||
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
|
||||||
if isSoraMedia {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var mode string
|
var mode string
|
||||||
switch {
|
switch {
|
||||||
case cost != nil && cost.BillingMode != "":
|
case cost != nil && cost.BillingMode != "":
|
||||||
@ -8182,13 +7850,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
|
|||||||
return &mode
|
return &mode
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
|
||||||
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
|
||||||
return &result.MediaType
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||||
if subscription != nil {
|
if subscription != nil {
|
||||||
return &subscription.ID
|
return &subscription.ID
|
||||||
|
|||||||
@ -9,35 +9,35 @@ import (
|
|||||||
|
|
||||||
func TestCollectSelectionFailureStats(t *testing.T) {
|
func TestCollectSelectionFailureStats(t *testing.T) {
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
model := "sora2-landscape-10s"
|
model := "gpt-5.4"
|
||||||
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
|
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
|
||||||
|
|
||||||
accounts := []Account{
|
accounts := []Account{
|
||||||
// excluded
|
// excluded
|
||||||
{
|
{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
},
|
},
|
||||||
// unschedulable
|
// unschedulable
|
||||||
{
|
{
|
||||||
ID: 2,
|
ID: 2,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: false,
|
Schedulable: false,
|
||||||
},
|
},
|
||||||
// platform filtered
|
// platform filtered
|
||||||
{
|
{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformAntigravity,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
},
|
},
|
||||||
// model unsupported
|
// model unsupported
|
||||||
{
|
{
|
||||||
ID: 4,
|
ID: 4,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
|
|||||||
// model rate limited
|
// model rate limited
|
||||||
{
|
{
|
||||||
ID: 5,
|
ID: 5,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
|
|||||||
// eligible
|
// eligible
|
||||||
{
|
{
|
||||||
ID: 6,
|
ID: 6,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
excluded := map[int64]struct{}{1: {}}
|
excluded := map[int64]struct{}{1: {}}
|
||||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
|
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false)
|
||||||
|
|
||||||
if stats.Total != 6 {
|
if stats.Total != 6 {
|
||||||
t.Fatalf("total=%d want=6", stats.Total)
|
t.Fatalf("total=%d want=6", stats.Total)
|
||||||
@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
|
func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) {
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
acc := &Account{
|
acc := &Account{
|
||||||
ID: 7,
|
ID: 7,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: false,
|
Schedulable: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false)
|
||||||
if diagnosis.Category != "unschedulable" {
|
if diagnosis.Category != "unschedulable" {
|
||||||
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
|
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
|
||||||
}
|
}
|
||||||
if diagnosis.Detail != "schedulable=false" {
|
if diagnosis.Detail != "generic_unschedulable" {
|
||||||
t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
|
t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
|
func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) {
|
||||||
svc := &GatewayService{}
|
svc := &GatewayService{}
|
||||||
model := "sora2-landscape-10s"
|
model := "gpt-5.4"
|
||||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||||
acc := &Account{
|
acc := &Account{
|
||||||
ID: 8,
|
ID: 8,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformOpenAI,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
|
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false)
|
||||||
if diagnosis.Category != "model_rate_limited" {
|
if diagnosis.Category != "model_rate_limited" {
|
||||||
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
|
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,79 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
account := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Credentials: map[string]any{},
|
|
||||||
}
|
|
||||||
|
|
||||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
|
||||||
t.Fatalf("expected sora model to be supported when model_mapping is empty")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
account := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"model_mapping": map[string]any{
|
|
||||||
"gpt-4o": "gpt-4o",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
|
||||||
t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
account := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"model_mapping": map[string]any{
|
|
||||||
"sora2": "sora2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
|
|
||||||
t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
account := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"model_mapping": map[string]any{
|
|
||||||
"sy_8": "sy_8",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
|
||||||
t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
account := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"model_mapping": map[string]any{
|
|
||||||
"gpt-image": "gpt-image",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
|
||||||
t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,89 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
now := time.Now()
|
|
||||||
past := now.Add(-1 * time.Minute)
|
|
||||||
future := now.Add(5 * time.Minute)
|
|
||||||
|
|
||||||
acc := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
AutoPauseOnExpired: true,
|
|
||||||
ExpiresAt: &past,
|
|
||||||
OverloadUntil: &future,
|
|
||||||
RateLimitResetAt: &future,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !svc.isAccountSchedulableForSelection(acc) {
|
|
||||||
t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
future := time.Now().Add(5 * time.Minute)
|
|
||||||
|
|
||||||
acc := &Account{
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
RateLimitResetAt: &future,
|
|
||||||
}
|
|
||||||
|
|
||||||
if svc.isAccountSchedulableForSelection(acc) {
|
|
||||||
t.Fatalf("expected non-sora account to keep generic schedulable checks")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
model := "sora2-landscape-10s"
|
|
||||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
|
||||||
globalResetAt := time.Now().Add(2 * time.Minute)
|
|
||||||
|
|
||||||
acc := &Account{
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
RateLimitResetAt: &globalResetAt,
|
|
||||||
Extra: map[string]any{
|
|
||||||
"model_rate_limits": map[string]any{
|
|
||||||
model: map[string]any{
|
|
||||||
"rate_limit_reset_at": resetAt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
|
|
||||||
t.Fatalf("expected sora account to be blocked by model scope rate limit")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
|
|
||||||
svc := &GatewayService{}
|
|
||||||
future := time.Now().Add(3 * time.Minute)
|
|
||||||
|
|
||||||
accounts := []Account{
|
|
||||||
{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
RateLimitResetAt: &future,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
|
||||||
if stats.Unschedulable != 0 || stats.Eligible != 1 {
|
|
||||||
t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -26,15 +26,6 @@ type Group struct {
|
|||||||
ImagePrice2K *float64
|
ImagePrice2K *float64
|
||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
|
|
||||||
// Sora 按次计费配置(阶段 1)
|
|
||||||
SoraImagePrice360 *float64
|
|
||||||
SoraImagePrice540 *float64
|
|
||||||
SoraVideoPricePerRequest *float64
|
|
||||||
SoraVideoPricePerRequestHD *float64
|
|
||||||
|
|
||||||
// Sora 存储配额
|
|
||||||
SoraStorageQuotaBytes int64
|
|
||||||
|
|
||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
ClaudeCodeOnly bool
|
ClaudeCodeOnly bool
|
||||||
FallbackGroupID *int64
|
FallbackGroupID *int64
|
||||||
@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
|
|
||||||
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
|
|
||||||
switch imageSize {
|
|
||||||
case "360":
|
|
||||||
return g.SoraImagePrice360
|
|
||||||
case "540":
|
|
||||||
return g.SoraImagePrice540
|
|
||||||
default:
|
|
||||||
return g.SoraImagePrice360
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||||
func IsGroupContextValid(group *Group) bool {
|
func IsGroupContextValid(group *Group) bool {
|
||||||
if group == nil {
|
if group == nil {
|
||||||
|
|||||||
@ -3,30 +3,15 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
|
||||||
|
|
||||||
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
|
|
||||||
|
|
||||||
type soraSessionChunk struct {
|
|
||||||
index int
|
|
||||||
value string
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||||
type OpenAIOAuthService struct {
|
type OpenAIOAuthService struct {
|
||||||
sessionStore *openai.SessionStore
|
sessionStore *openai.SessionStore
|
||||||
@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
|||||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
|
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
|
||||||
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||||||
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
|
|||||||
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
|
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
// RefreshAccountToken refreshes token for an OpenAI OAuth account
|
||||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
|
||||||
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
|
|
||||||
if strings.TrimSpace(sessionToken) == "" {
|
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
|
||||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
|
||||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
|
||||||
|
|
||||||
client, err := httpclient.GetClient(httpclient.Options{
|
|
||||||
ProxyURL: proxyURL,
|
|
||||||
Timeout: 120 * time.Second,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
|
|
||||||
}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
||||||
}
|
|
||||||
|
|
||||||
var sessionResp struct {
|
|
||||||
AccessToken string `json:"accessToken"`
|
|
||||||
Expires string `json:"expires"`
|
|
||||||
User struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
} `json:"user"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &sessionResp); err != nil {
|
|
||||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(sessionResp.AccessToken) == "" {
|
|
||||||
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
|
|
||||||
}
|
|
||||||
|
|
||||||
expiresAt := time.Now().Add(time.Hour).Unix()
|
|
||||||
if strings.TrimSpace(sessionResp.Expires) != "" {
|
|
||||||
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
|
|
||||||
expiresAt = parsed.Unix()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expiresIn := expiresAt - time.Now().Unix()
|
|
||||||
if expiresIn < 0 {
|
|
||||||
expiresIn = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return &OpenAITokenInfo{
|
|
||||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
|
||||||
ExpiresIn: expiresIn,
|
|
||||||
ExpiresAt: expiresAt,
|
|
||||||
ClientID: openai.SoraClientID,
|
|
||||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeSoraSessionTokenInput(raw string) string {
|
|
||||||
trimmed := strings.TrimSpace(raw)
|
|
||||||
if trimmed == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return sanitizeSessionToken(trimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
chunkMatches := make([]soraSessionChunk, 0, len(matches))
|
|
||||||
singleValues := make([]string, 0, len(matches))
|
|
||||||
|
|
||||||
for _, match := range matches {
|
|
||||||
if len(match) < 3 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
value := sanitizeSessionToken(match[2])
|
|
||||||
if value == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(match[1]) == "" {
|
|
||||||
singleValues = append(singleValues, value)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
|
|
||||||
if err != nil || idx < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
chunkMatches = append(chunkMatches, soraSessionChunk{
|
|
||||||
index: idx,
|
|
||||||
value: value,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
|
|
||||||
return merged
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(singleValues) > 0 {
|
|
||||||
return singleValues[len(singleValues)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
|
|
||||||
if len(chunks) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
byIndex := make(map[int]string, len(chunks))
|
|
||||||
for _, chunk := range chunks {
|
|
||||||
byIndex[chunk.index] = chunk.value
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := byIndex[0]; !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if requireComplete {
|
|
||||||
for idx := 0; idx <= requiredMaxIndex; idx++ {
|
|
||||||
if _, ok := byIndex[idx]; !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
orderedIndexes := make([]int, 0, len(byIndex))
|
|
||||||
for idx := range byIndex {
|
|
||||||
orderedIndexes = append(orderedIndexes, idx)
|
|
||||||
}
|
|
||||||
sort.Ints(orderedIndexes)
|
|
||||||
|
|
||||||
var builder strings.Builder
|
|
||||||
for _, idx := range orderedIndexes {
|
|
||||||
if _, err := builder.WriteString(byIndex[idx]); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sanitizeSessionToken(builder.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
|
|
||||||
if len(chunks) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
requiredMaxIndex := 0
|
|
||||||
for _, chunk := range chunks {
|
|
||||||
if chunk.index > requiredMaxIndex {
|
|
||||||
requiredMaxIndex = chunk.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
groupStarts := make([]int, 0, len(chunks))
|
|
||||||
for idx, chunk := range chunks {
|
|
||||||
if chunk.index == 0 {
|
|
||||||
groupStarts = append(groupStarts, idx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(groupStarts) == 0 {
|
|
||||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := len(groupStarts) - 1; i >= 0; i-- {
|
|
||||||
start := groupStarts[i]
|
|
||||||
end := len(chunks)
|
|
||||||
if i+1 < len(groupStarts) {
|
|
||||||
end = groupStarts[i+1]
|
|
||||||
}
|
|
||||||
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
|
|
||||||
return merged
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeSessionToken(raw string) string {
|
|
||||||
token := strings.TrimSpace(raw)
|
|
||||||
token = strings.Trim(token, "\"'`")
|
|
||||||
token = strings.TrimSuffix(token, ";")
|
|
||||||
return strings.TrimSpace(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
|
||||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
if account.Platform != PlatformOpenAI {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||||
}
|
}
|
||||||
if account.Type != AccountTypeOAuth {
|
if account.Type != AccountTypeOAuth {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||||||
@ -594,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() {
|
|||||||
s.sessionStore.Stop()
|
s.sessionStore.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
|
||||||
if proxyID == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
|
||||||
if err != nil {
|
|
||||||
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
|
||||||
}
|
|
||||||
if proxy == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return proxy.URL(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeOpenAIOAuthPlatform(platform string) string {
|
func normalizeOpenAIOAuthPlatform(platform string) string {
|
||||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
return openai.OAuthPlatformOpenAI
|
||||||
case PlatformSora:
|
|
||||||
return openai.OAuthPlatformSora
|
|
||||||
default:
|
|
||||||
return openai.OAuthPlatformOpenAI
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, openai.ClientID, session.ClientID)
|
require.Equal(t, openai.ClientID, session.ClientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
|
|
||||||
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
|
|
||||||
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, result.AuthURL)
|
|
||||||
require.NotEmpty(t, result.SessionID)
|
|
||||||
|
|
||||||
parsed, err := url.Parse(result.AuthURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
q := parsed.Query()
|
|
||||||
require.Equal(t, openai.ClientID, q.Get("client_id"))
|
|
||||||
require.Empty(t, q.Get("codex_cli_simplified_flow"))
|
|
||||||
|
|
||||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, openai.ClientID, session.ClientID)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,173 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type openaiOAuthClientNoopStub struct{}
|
|
||||||
|
|
||||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, http.MethodGet, r.Method)
|
|
||||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
origin := openAISoraSessionAuthURL
|
|
||||||
openAISoraSessionAuthURL = server.URL
|
|
||||||
defer func() { openAISoraSessionAuthURL = origin }()
|
|
||||||
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, info)
|
|
||||||
require.Equal(t, "at-token", info.AccessToken)
|
|
||||||
require.Equal(t, "demo@example.com", info.Email)
|
|
||||||
require.Greater(t, info.ExpiresAt, int64(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
origin := openAISoraSessionAuthURL
|
|
||||||
openAISoraSessionAuthURL = server.URL
|
|
||||||
defer func() { openAISoraSessionAuthURL = origin }()
|
|
||||||
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "missing access token")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, http.MethodGet, r.Method)
|
|
||||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
origin := openAISoraSessionAuthURL
|
|
||||||
openAISoraSessionAuthURL = server.URL
|
|
||||||
defer func() { openAISoraSessionAuthURL = origin }()
|
|
||||||
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
|
|
||||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "at-token", info.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, http.MethodGet, r.Method)
|
|
||||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
origin := openAISoraSessionAuthURL
|
|
||||||
openAISoraSessionAuthURL = server.URL
|
|
||||||
defer func() { openAISoraSessionAuthURL = origin }()
|
|
||||||
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
raw := strings.Join([]string{
|
|
||||||
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
|
|
||||||
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
|
|
||||||
}, "\n")
|
|
||||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "at-token", info.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, http.MethodGet, r.Method)
|
|
||||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
origin := openAISoraSessionAuthURL
|
|
||||||
openAISoraSessionAuthURL = server.URL
|
|
||||||
defer func() { openAISoraSessionAuthURL = origin }()
|
|
||||||
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
raw := strings.Join([]string{
|
|
||||||
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
|
|
||||||
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
|
|
||||||
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
|
|
||||||
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
|
|
||||||
}, "\n")
|
|
||||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "at-token", info.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, http.MethodGet, r.Method)
|
|
||||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
origin := openAISoraSessionAuthURL
|
|
||||||
openAISoraSessionAuthURL = server.URL
|
|
||||||
defer func() { openAISoraSessionAuthURL = origin }()
|
|
||||||
|
|
||||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
|
||||||
defer svc.Stop()
|
|
||||||
|
|
||||||
raw := strings.Join([]string{
|
|
||||||
"set-cookie",
|
|
||||||
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
|
|
||||||
"set-cookie",
|
|
||||||
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
|
|
||||||
"set-cookie",
|
|
||||||
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
|
|
||||||
}, "\n")
|
|
||||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "at-token", info.AccessToken)
|
|
||||||
}
|
|
||||||
@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
|
|||||||
// OpenAITokenCache token cache interface.
|
// OpenAITokenCache token cache interface.
|
||||||
type OpenAITokenCache = GeminiTokenCache
|
type OpenAITokenCache = GeminiTokenCache
|
||||||
|
|
||||||
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
|
// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
|
||||||
type OpenAITokenProvider struct {
|
type OpenAITokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache OpenAITokenCache
|
tokenCache OpenAITokenCache
|
||||||
@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
}
|
}
|
||||||
if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
|
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||||
return "", errors.New("not an openai/sora oauth account")
|
return "", errors.New("not an openai oauth account")
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey := OpenAITokenCacheKey(account)
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
p.metrics.refreshRequests.Add(1)
|
p.metrics.refreshRequests.Add(1)
|
||||||
p.metrics.touchNow()
|
p.metrics.touchNow()
|
||||||
|
|
||||||
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
||||||
if account.Platform == PlatformSora {
|
if err != nil {
|
||||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
refreshFailed = true
|
return "", err
|
||||||
} else {
|
|
||||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
|
||||||
if err != nil {
|
|
||||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
|
||||||
p.metrics.refreshFailure.Add(1)
|
|
||||||
refreshFailed = true
|
|
||||||
} else if result.LockHeld {
|
|
||||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
|
||||||
p.metrics.lockContention.Add(1)
|
|
||||||
p.metrics.touchNow()
|
|
||||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
|
||||||
if waitErr != nil {
|
|
||||||
return "", waitErr
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(token) != "" {
|
|
||||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if result.Refreshed {
|
|
||||||
p.metrics.refreshSuccess.Add(1)
|
|
||||||
account = result.Account
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
} else {
|
|
||||||
account = result.Account
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
}
|
||||||
|
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||||
|
p.metrics.refreshFailure.Add(1)
|
||||||
|
refreshFailed = true
|
||||||
|
} else if result.LockHeld {
|
||||||
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
||||||
|
p.metrics.lockContention.Add(1)
|
||||||
|
p.metrics.touchNow()
|
||||||
|
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||||
|
if waitErr != nil {
|
||||||
|
return "", waitErr
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
|
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if result.Refreshed {
|
||||||
|
p.metrics.refreshSuccess.Add(1)
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
}
|
}
|
||||||
} else if needsRefresh && p.tokenCache != nil {
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
// Backward-compatible test path when refreshAPI is not injected.
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
|
|||||||
@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
|
|||||||
|
|
||||||
token, err := provider.GetAccessToken(context.Background(), account)
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "not an openai/sora oauth account")
|
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||||
require.Empty(t, token)
|
require.Empty(t, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
|
|||||||
|
|
||||||
token, err := provider.GetAccessToken(context.Background(), account)
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "not an openai/sora oauth account")
|
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||||
require.Empty(t, token)
|
require.Empty(t, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -22,8 +22,6 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
|
||||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
|
||||||
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
|
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
|
||||||
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
|
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
|
||||||
"default subscription group must exist and be subscription type",
|
"default subscription group must exist and be subscription type",
|
||||||
@ -104,7 +102,6 @@ type SettingService struct {
|
|||||||
defaultSubGroupReader DefaultSubscriptionGroupReader
|
defaultSubGroupReader DefaultSubscriptionGroupReader
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||||
onS3Update func() // Callback when Sora S3 settings are updated
|
|
||||||
version string // Application version
|
version string // Application version
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
SettingKeyHideCcsImportButton,
|
SettingKeyHideCcsImportButton,
|
||||||
SettingKeyPurchaseSubscriptionEnabled,
|
SettingKeyPurchaseSubscriptionEnabled,
|
||||||
SettingKeyPurchaseSubscriptionURL,
|
SettingKeyPurchaseSubscriptionURL,
|
||||||
SettingKeySoraClientEnabled,
|
|
||||||
SettingKeyCustomMenuItems,
|
SettingKeyCustomMenuItems,
|
||||||
SettingKeyCustomEndpoints,
|
SettingKeyCustomEndpoints,
|
||||||
SettingKeyLinuxDoConnectEnabled,
|
SettingKeyLinuxDoConnectEnabled,
|
||||||
@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
|
||||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||||
@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
|||||||
s.onUpdate = callback
|
s.onUpdate = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
|
|
||||||
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
|
|
||||||
s.onS3Update = callback
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVersion sets the application version for injection into public settings
|
// SetVersion sets the application version for injection into public settings
|
||||||
func (s *SettingService) SetVersion(version string) {
|
func (s *SettingService) SetVersion(version string) {
|
||||||
s.version = version
|
s.version = version
|
||||||
@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
|||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
|
||||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
|||||||
HideCcsImportButton: settings.HideCcsImportButton,
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
|
||||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
|
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
||||||
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
|
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
|
||||||
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
|
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
|
||||||
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
|
|
||||||
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
|
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
|
||||||
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
|
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
|
||||||
|
|
||||||
@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeySiteLogo: "",
|
SettingKeySiteLogo: "",
|
||||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||||
SettingKeyPurchaseSubscriptionURL: "",
|
SettingKeyPurchaseSubscriptionURL: "",
|
||||||
SettingKeySoraClientEnabled: "false",
|
|
||||||
SettingKeyCustomMenuItems: "[]",
|
SettingKeyCustomMenuItems: "[]",
|
||||||
SettingKeyCustomEndpoints: "[]",
|
SettingKeyCustomEndpoints: "[]",
|
||||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||||
@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
|
||||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||||
@ -1583,607 +1568,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
|
|||||||
|
|
||||||
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
|
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
type soraS3ProfilesStore struct {
|
|
||||||
ActiveProfileID string `json:"active_profile_id"`
|
|
||||||
Items []soraS3ProfileStoreItem `json:"items"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type soraS3ProfileStoreItem struct {
|
|
||||||
ProfileID string `json:"profile_id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKey string `json:"secret_access_key"`
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
UpdatedAt string `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
|
|
||||||
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
|
|
||||||
profiles, err := s.ListSoraS3Profiles(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
|
|
||||||
if activeProfile == nil {
|
|
||||||
return &SoraS3Settings{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &SoraS3Settings{
|
|
||||||
Enabled: activeProfile.Enabled,
|
|
||||||
Endpoint: activeProfile.Endpoint,
|
|
||||||
Region: activeProfile.Region,
|
|
||||||
Bucket: activeProfile.Bucket,
|
|
||||||
AccessKeyID: activeProfile.AccessKeyID,
|
|
||||||
SecretAccessKey: activeProfile.SecretAccessKey,
|
|
||||||
SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
|
|
||||||
Prefix: activeProfile.Prefix,
|
|
||||||
ForcePathStyle: activeProfile.ForcePathStyle,
|
|
||||||
CDNURL: activeProfile.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
|
|
||||||
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
|
|
||||||
if settings == nil {
|
|
||||||
return fmt.Errorf("settings cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
|
||||||
activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
|
|
||||||
if activeIndex < 0 {
|
|
||||||
activeID := "default"
|
|
||||||
if hasSoraS3ProfileID(store.Items, activeID) {
|
|
||||||
activeID = fmt.Sprintf("default-%d", time.Now().Unix())
|
|
||||||
}
|
|
||||||
store.Items = append(store.Items, soraS3ProfileStoreItem{
|
|
||||||
ProfileID: activeID,
|
|
||||||
Name: "Default",
|
|
||||||
UpdatedAt: now,
|
|
||||||
})
|
|
||||||
store.ActiveProfileID = activeID
|
|
||||||
activeIndex = len(store.Items) - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
active := store.Items[activeIndex]
|
|
||||||
active.Enabled = settings.Enabled
|
|
||||||
active.Endpoint = strings.TrimSpace(settings.Endpoint)
|
|
||||||
active.Region = strings.TrimSpace(settings.Region)
|
|
||||||
active.Bucket = strings.TrimSpace(settings.Bucket)
|
|
||||||
active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
|
|
||||||
active.Prefix = strings.TrimSpace(settings.Prefix)
|
|
||||||
active.ForcePathStyle = settings.ForcePathStyle
|
|
||||||
active.CDNURL = strings.TrimSpace(settings.CDNURL)
|
|
||||||
active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
|
|
||||||
if settings.SecretAccessKey != "" {
|
|
||||||
active.SecretAccessKey = settings.SecretAccessKey
|
|
||||||
}
|
|
||||||
active.UpdatedAt = now
|
|
||||||
store.Items[activeIndex] = active
|
|
||||||
|
|
||||||
return s.persistSoraS3ProfilesStore(ctx, store)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListSoraS3Profiles 获取 Sora S3 多配置列表
|
|
||||||
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
|
|
||||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return convertSoraS3ProfilesStore(store), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
|
||||||
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
|
|
||||||
if profile == nil {
|
|
||||||
return nil, fmt.Errorf("profile cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
profileID := strings.TrimSpace(profile.ProfileID)
|
|
||||||
if profileID == "" {
|
|
||||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
|
||||||
}
|
|
||||||
name := strings.TrimSpace(profile.Name)
|
|
||||||
if name == "" {
|
|
||||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if hasSoraS3ProfileID(store.Items, profileID) {
|
|
||||||
return nil, ErrSoraS3ProfileExists
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
|
||||||
store.Items = append(store.Items, soraS3ProfileStoreItem{
|
|
||||||
ProfileID: profileID,
|
|
||||||
Name: name,
|
|
||||||
Enabled: profile.Enabled,
|
|
||||||
Endpoint: strings.TrimSpace(profile.Endpoint),
|
|
||||||
Region: strings.TrimSpace(profile.Region),
|
|
||||||
Bucket: strings.TrimSpace(profile.Bucket),
|
|
||||||
AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
|
|
||||||
SecretAccessKey: profile.SecretAccessKey,
|
|
||||||
Prefix: strings.TrimSpace(profile.Prefix),
|
|
||||||
ForcePathStyle: profile.ForcePathStyle,
|
|
||||||
CDNURL: strings.TrimSpace(profile.CDNURL),
|
|
||||||
DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
|
|
||||||
UpdatedAt: now,
|
|
||||||
})
|
|
||||||
|
|
||||||
if setActive || store.ActiveProfileID == "" {
|
|
||||||
store.ActiveProfileID = profileID
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := convertSoraS3ProfilesStore(store)
|
|
||||||
created := findSoraS3ProfileByID(profiles.Items, profileID)
|
|
||||||
if created == nil {
|
|
||||||
return nil, ErrSoraS3ProfileNotFound
|
|
||||||
}
|
|
||||||
return created, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
|
||||||
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
|
|
||||||
if profile == nil {
|
|
||||||
return nil, fmt.Errorf("profile cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
targetID := strings.TrimSpace(profileID)
|
|
||||||
if targetID == "" {
|
|
||||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
|
||||||
if targetIndex < 0 {
|
|
||||||
return nil, ErrSoraS3ProfileNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
target := store.Items[targetIndex]
|
|
||||||
name := strings.TrimSpace(profile.Name)
|
|
||||||
if name == "" {
|
|
||||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
|
|
||||||
}
|
|
||||||
target.Name = name
|
|
||||||
target.Enabled = profile.Enabled
|
|
||||||
target.Endpoint = strings.TrimSpace(profile.Endpoint)
|
|
||||||
target.Region = strings.TrimSpace(profile.Region)
|
|
||||||
target.Bucket = strings.TrimSpace(profile.Bucket)
|
|
||||||
target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
|
|
||||||
target.Prefix = strings.TrimSpace(profile.Prefix)
|
|
||||||
target.ForcePathStyle = profile.ForcePathStyle
|
|
||||||
target.CDNURL = strings.TrimSpace(profile.CDNURL)
|
|
||||||
target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
|
|
||||||
if profile.SecretAccessKey != "" {
|
|
||||||
target.SecretAccessKey = profile.SecretAccessKey
|
|
||||||
}
|
|
||||||
target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
|
||||||
store.Items[targetIndex] = target
|
|
||||||
|
|
||||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := convertSoraS3ProfilesStore(store)
|
|
||||||
updated := findSoraS3ProfileByID(profiles.Items, targetID)
|
|
||||||
if updated == nil {
|
|
||||||
return nil, ErrSoraS3ProfileNotFound
|
|
||||||
}
|
|
||||||
return updated, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
|
||||||
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
|
|
||||||
targetID := strings.TrimSpace(profileID)
|
|
||||||
if targetID == "" {
|
|
||||||
return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
|
||||||
if targetIndex < 0 {
|
|
||||||
return ErrSoraS3ProfileNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
|
|
||||||
if store.ActiveProfileID == targetID {
|
|
||||||
store.ActiveProfileID = ""
|
|
||||||
if len(store.Items) > 0 {
|
|
||||||
store.ActiveProfileID = store.Items[0].ProfileID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.persistSoraS3ProfilesStore(ctx, store)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
|
|
||||||
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
|
|
||||||
targetID := strings.TrimSpace(profileID)
|
|
||||||
if targetID == "" {
|
|
||||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
|
||||||
if targetIndex < 0 {
|
|
||||||
return nil, ErrSoraS3ProfileNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
store.ActiveProfileID = targetID
|
|
||||||
store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
|
||||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := convertSoraS3ProfilesStore(store)
|
|
||||||
active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
|
|
||||||
if active == nil {
|
|
||||||
return nil, ErrSoraS3ProfileNotFound
|
|
||||||
}
|
|
||||||
return active, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
|
|
||||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
|
|
||||||
if err == nil {
|
|
||||||
trimmed := strings.TrimSpace(raw)
|
|
||||||
if trimmed == "" {
|
|
||||||
return &soraS3ProfilesStore{}, nil
|
|
||||||
}
|
|
||||||
var store soraS3ProfilesStore
|
|
||||||
if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
|
|
||||||
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
|
|
||||||
if legacyErr != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
|
|
||||||
}
|
|
||||||
if isEmptyLegacySoraS3Settings(legacy) {
|
|
||||||
return &soraS3ProfilesStore{}, nil
|
|
||||||
}
|
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
|
||||||
return &soraS3ProfilesStore{
|
|
||||||
ActiveProfileID: "default",
|
|
||||||
Items: []soraS3ProfileStoreItem{
|
|
||||||
{
|
|
||||||
ProfileID: "default",
|
|
||||||
Name: "Default",
|
|
||||||
Enabled: legacy.Enabled,
|
|
||||||
Endpoint: strings.TrimSpace(legacy.Endpoint),
|
|
||||||
Region: strings.TrimSpace(legacy.Region),
|
|
||||||
Bucket: strings.TrimSpace(legacy.Bucket),
|
|
||||||
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
|
|
||||||
SecretAccessKey: legacy.SecretAccessKey,
|
|
||||||
Prefix: strings.TrimSpace(legacy.Prefix),
|
|
||||||
ForcePathStyle: legacy.ForcePathStyle,
|
|
||||||
CDNURL: strings.TrimSpace(legacy.CDNURL),
|
|
||||||
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
|
|
||||||
UpdatedAt: now,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
normalized := normalizeSoraS3ProfilesStore(store)
|
|
||||||
return &normalized, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !errors.Is(err, ErrSettingNotFound) {
|
|
||||||
return nil, fmt.Errorf("get sora s3 profiles: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
|
|
||||||
if legacyErr != nil {
|
|
||||||
return nil, legacyErr
|
|
||||||
}
|
|
||||||
if isEmptyLegacySoraS3Settings(legacy) {
|
|
||||||
return &soraS3ProfilesStore{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
|
||||||
return &soraS3ProfilesStore{
|
|
||||||
ActiveProfileID: "default",
|
|
||||||
Items: []soraS3ProfileStoreItem{
|
|
||||||
{
|
|
||||||
ProfileID: "default",
|
|
||||||
Name: "Default",
|
|
||||||
Enabled: legacy.Enabled,
|
|
||||||
Endpoint: strings.TrimSpace(legacy.Endpoint),
|
|
||||||
Region: strings.TrimSpace(legacy.Region),
|
|
||||||
Bucket: strings.TrimSpace(legacy.Bucket),
|
|
||||||
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
|
|
||||||
SecretAccessKey: legacy.SecretAccessKey,
|
|
||||||
Prefix: strings.TrimSpace(legacy.Prefix),
|
|
||||||
ForcePathStyle: legacy.ForcePathStyle,
|
|
||||||
CDNURL: strings.TrimSpace(legacy.CDNURL),
|
|
||||||
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
|
|
||||||
UpdatedAt: now,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
|
|
||||||
if store == nil {
|
|
||||||
return fmt.Errorf("sora s3 profiles store cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized := normalizeSoraS3ProfilesStore(*store)
|
|
||||||
data, err := json.Marshal(normalized)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal sora s3 profiles: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
updates := map[string]string{
|
|
||||||
SettingKeySoraS3Profiles: string(data),
|
|
||||||
}
|
|
||||||
|
|
||||||
active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
|
|
||||||
if active == nil {
|
|
||||||
updates[SettingKeySoraS3Enabled] = "false"
|
|
||||||
updates[SettingKeySoraS3Endpoint] = ""
|
|
||||||
updates[SettingKeySoraS3Region] = ""
|
|
||||||
updates[SettingKeySoraS3Bucket] = ""
|
|
||||||
updates[SettingKeySoraS3AccessKeyID] = ""
|
|
||||||
updates[SettingKeySoraS3Prefix] = ""
|
|
||||||
updates[SettingKeySoraS3ForcePathStyle] = "false"
|
|
||||||
updates[SettingKeySoraS3CDNURL] = ""
|
|
||||||
updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
|
|
||||||
updates[SettingKeySoraS3SecretAccessKey] = ""
|
|
||||||
} else {
|
|
||||||
updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
|
|
||||||
updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
|
|
||||||
updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
|
|
||||||
updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
|
|
||||||
updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
|
|
||||||
updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
|
|
||||||
updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
|
|
||||||
updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
|
|
||||||
updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
|
|
||||||
updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.onUpdate != nil {
|
|
||||||
s.onUpdate()
|
|
||||||
}
|
|
||||||
if s.onS3Update != nil {
|
|
||||||
s.onS3Update()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
|
|
||||||
keys := []string{
|
|
||||||
SettingKeySoraS3Enabled,
|
|
||||||
SettingKeySoraS3Endpoint,
|
|
||||||
SettingKeySoraS3Region,
|
|
||||||
SettingKeySoraS3Bucket,
|
|
||||||
SettingKeySoraS3AccessKeyID,
|
|
||||||
SettingKeySoraS3SecretAccessKey,
|
|
||||||
SettingKeySoraS3Prefix,
|
|
||||||
SettingKeySoraS3ForcePathStyle,
|
|
||||||
SettingKeySoraS3CDNURL,
|
|
||||||
SettingKeySoraDefaultStorageQuotaBytes,
|
|
||||||
}
|
|
||||||
|
|
||||||
values, err := s.settingRepo.GetMultiple(ctx, keys)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := &SoraS3Settings{
|
|
||||||
Enabled: values[SettingKeySoraS3Enabled] == "true",
|
|
||||||
Endpoint: values[SettingKeySoraS3Endpoint],
|
|
||||||
Region: values[SettingKeySoraS3Region],
|
|
||||||
Bucket: values[SettingKeySoraS3Bucket],
|
|
||||||
AccessKeyID: values[SettingKeySoraS3AccessKeyID],
|
|
||||||
SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
|
|
||||||
SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
|
|
||||||
Prefix: values[SettingKeySoraS3Prefix],
|
|
||||||
ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
|
|
||||||
CDNURL: values[SettingKeySoraS3CDNURL],
|
|
||||||
}
|
|
||||||
if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
|
|
||||||
result.DefaultStorageQuotaBytes = v
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
|
|
||||||
seen := make(map[string]struct{}, len(store.Items))
|
|
||||||
normalized := soraS3ProfilesStore{
|
|
||||||
ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
|
|
||||||
Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
|
|
||||||
}
|
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
|
||||||
|
|
||||||
for idx := range store.Items {
|
|
||||||
item := store.Items[idx]
|
|
||||||
item.ProfileID = strings.TrimSpace(item.ProfileID)
|
|
||||||
if item.ProfileID == "" {
|
|
||||||
item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
|
|
||||||
}
|
|
||||||
if _, exists := seen[item.ProfileID]; exists {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[item.ProfileID] = struct{}{}
|
|
||||||
|
|
||||||
item.Name = strings.TrimSpace(item.Name)
|
|
||||||
if item.Name == "" {
|
|
||||||
item.Name = item.ProfileID
|
|
||||||
}
|
|
||||||
item.Endpoint = strings.TrimSpace(item.Endpoint)
|
|
||||||
item.Region = strings.TrimSpace(item.Region)
|
|
||||||
item.Bucket = strings.TrimSpace(item.Bucket)
|
|
||||||
item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
|
|
||||||
item.Prefix = strings.TrimSpace(item.Prefix)
|
|
||||||
item.CDNURL = strings.TrimSpace(item.CDNURL)
|
|
||||||
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
|
|
||||||
item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
|
|
||||||
if item.UpdatedAt == "" {
|
|
||||||
item.UpdatedAt = now
|
|
||||||
}
|
|
||||||
normalized.Items = append(normalized.Items, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(normalized.Items) == 0 {
|
|
||||||
normalized.ActiveProfileID = ""
|
|
||||||
return normalized
|
|
||||||
}
|
|
||||||
|
|
||||||
if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
|
|
||||||
return normalized
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized.ActiveProfileID = normalized.Items[0].ProfileID
|
|
||||||
return normalized
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
|
|
||||||
if store == nil {
|
|
||||||
return &SoraS3ProfileList{}
|
|
||||||
}
|
|
||||||
items := make([]SoraS3Profile, 0, len(store.Items))
|
|
||||||
for idx := range store.Items {
|
|
||||||
item := store.Items[idx]
|
|
||||||
items = append(items, SoraS3Profile{
|
|
||||||
ProfileID: item.ProfileID,
|
|
||||||
Name: item.Name,
|
|
||||||
IsActive: item.ProfileID == store.ActiveProfileID,
|
|
||||||
Enabled: item.Enabled,
|
|
||||||
Endpoint: item.Endpoint,
|
|
||||||
Region: item.Region,
|
|
||||||
Bucket: item.Bucket,
|
|
||||||
AccessKeyID: item.AccessKeyID,
|
|
||||||
SecretAccessKey: item.SecretAccessKey,
|
|
||||||
SecretAccessKeyConfigured: item.SecretAccessKey != "",
|
|
||||||
Prefix: item.Prefix,
|
|
||||||
ForcePathStyle: item.ForcePathStyle,
|
|
||||||
CDNURL: item.CDNURL,
|
|
||||||
DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
|
|
||||||
UpdatedAt: item.UpdatedAt,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &SoraS3ProfileList{
|
|
||||||
ActiveProfileID: store.ActiveProfileID,
|
|
||||||
Items: items,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
|
|
||||||
for idx := range items {
|
|
||||||
if items[idx].ProfileID == activeProfileID {
|
|
||||||
return &items[idx]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(items) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &items[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
|
|
||||||
for idx := range items {
|
|
||||||
if items[idx].ProfileID == profileID {
|
|
||||||
return &items[idx]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
|
|
||||||
for idx := range items {
|
|
||||||
if items[idx].ProfileID == activeProfileID {
|
|
||||||
return &items[idx]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(items) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &items[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
|
|
||||||
for idx := range items {
|
|
||||||
if items[idx].ProfileID == profileID {
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
|
|
||||||
return findSoraS3ProfileIndex(items, profileID) >= 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
|
|
||||||
if settings == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if settings.Enabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(settings.Endpoint) != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(settings.Region) != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(settings.Bucket) != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(settings.AccessKeyID) != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if settings.SecretAccessKey != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(settings.Prefix) != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(settings.CDNURL) != "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return settings.DefaultStorageQuotaBytes == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func maxInt64(value int64, min int64) int64 {
|
|
||||||
if value < min {
|
|
||||||
return min
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|||||||
@ -41,7 +41,6 @@ type SystemSettings struct {
|
|||||||
HideCcsImportButton bool
|
HideCcsImportButton bool
|
||||||
PurchaseSubscriptionEnabled bool
|
PurchaseSubscriptionEnabled bool
|
||||||
PurchaseSubscriptionURL string
|
PurchaseSubscriptionURL string
|
||||||
SoraClientEnabled bool
|
|
||||||
CustomMenuItems string // JSON array of custom menu items
|
CustomMenuItems string // JSON array of custom menu items
|
||||||
CustomEndpoints string // JSON array of custom endpoints
|
CustomEndpoints string // JSON array of custom endpoints
|
||||||
|
|
||||||
@ -107,7 +106,6 @@ type PublicSettings struct {
|
|||||||
|
|
||||||
PurchaseSubscriptionEnabled bool
|
PurchaseSubscriptionEnabled bool
|
||||||
PurchaseSubscriptionURL string
|
PurchaseSubscriptionURL string
|
||||||
SoraClientEnabled bool
|
|
||||||
CustomMenuItems string // JSON array of custom menu items
|
CustomMenuItems string // JSON array of custom menu items
|
||||||
CustomEndpoints string // JSON array of custom endpoints
|
CustomEndpoints string // JSON array of custom endpoints
|
||||||
|
|
||||||
@ -116,46 +114,6 @@ type PublicSettings struct {
|
|||||||
Version string
|
Version string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraS3Settings Sora S3 存储配置
|
|
||||||
type SoraS3Settings struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
|
|
||||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraS3Profile Sora S3 多配置项(服务内部模型)
|
|
||||||
type SoraS3Profile struct {
|
|
||||||
ProfileID string `json:"profile_id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
IsActive bool `json:"is_active"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Bucket string `json:"bucket"`
|
|
||||||
AccessKeyID string `json:"access_key_id"`
|
|
||||||
SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
|
|
||||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
ForcePathStyle bool `json:"force_path_style"`
|
|
||||||
CDNURL string `json:"cdn_url"`
|
|
||||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
|
||||||
UpdatedAt string `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraS3ProfileList Sora S3 多配置列表
|
|
||||||
type SoraS3ProfileList struct {
|
|
||||||
ActiveProfileID string `json:"active_profile_id"`
|
|
||||||
Items []SoraS3Profile `json:"items"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
|
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
|
||||||
type StreamTimeoutSettings struct {
|
type StreamTimeoutSettings struct {
|
||||||
// Enabled 是否启用流超时处理
|
// Enabled 是否启用流超时处理
|
||||||
|
|||||||
@ -1,40 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
// SoraAccountRepository Sora 账号扩展表仓储接口
|
|
||||||
// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
|
|
||||||
//
|
|
||||||
// 设计说明:
|
|
||||||
// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
|
|
||||||
// - Sora gateway 优先读取此表的字段以获得更好的查询性能
|
|
||||||
// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
|
|
||||||
// - Token 刷新时需要同时更新两个表以保持数据一致性
|
|
||||||
type SoraAccountRepository interface {
|
|
||||||
// Upsert 创建或更新 Sora 账号扩展信息
|
|
||||||
// accountID: 关联的 accounts.id
|
|
||||||
// updates: 要更新的字段,支持 access_token、refresh_token、session_token
|
|
||||||
//
|
|
||||||
// 如果记录不存在则创建,存在则更新。
|
|
||||||
// 用于:
|
|
||||||
// 1. 创建 Sora 账号时初始化扩展表
|
|
||||||
// 2. Token 刷新时同步更新扩展表
|
|
||||||
Upsert(ctx context.Context, accountID int64, updates map[string]any) error
|
|
||||||
|
|
||||||
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
|
|
||||||
// 返回 nil, nil 表示记录不存在(非错误)
|
|
||||||
GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
|
|
||||||
|
|
||||||
// Delete 删除 Sora 账号扩展信息
|
|
||||||
// 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
|
|
||||||
Delete(ctx context.Context, accountID int64) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraAccount Sora 账号扩展信息
|
|
||||||
// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
|
|
||||||
type SoraAccount struct {
|
|
||||||
AccountID int64 // 关联的 accounts.id
|
|
||||||
AccessToken string // OAuth access_token
|
|
||||||
RefreshToken string // OAuth refresh_token
|
|
||||||
SessionToken string // Session token(可选,用于 ST→AT 兜底)
|
|
||||||
}
|
|
||||||
@ -1,117 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SoraClient 定义直连 Sora 的任务操作接口。
|
|
||||||
type SoraClient interface {
|
|
||||||
Enabled() bool
|
|
||||||
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
|
||||||
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
|
||||||
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
|
||||||
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
|
|
||||||
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
|
|
||||||
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
|
|
||||||
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
|
|
||||||
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
|
|
||||||
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
|
|
||||||
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
|
|
||||||
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
|
|
||||||
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
|
|
||||||
DeletePost(ctx context.Context, account *Account, postID string) error
|
|
||||||
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
|
|
||||||
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
|
|
||||||
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
|
||||||
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImageRequest 图片生成请求参数
|
|
||||||
type SoraImageRequest struct {
|
|
||||||
Prompt string
|
|
||||||
Width int
|
|
||||||
Height int
|
|
||||||
MediaID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoRequest 视频生成请求参数
|
|
||||||
type SoraVideoRequest struct {
|
|
||||||
Prompt string
|
|
||||||
Orientation string
|
|
||||||
Frames int
|
|
||||||
Model string
|
|
||||||
Size string
|
|
||||||
VideoCount int
|
|
||||||
MediaID string
|
|
||||||
RemixTargetID string
|
|
||||||
CameoIDs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraStoryboardRequest 分镜视频生成请求参数
|
|
||||||
type SoraStoryboardRequest struct {
|
|
||||||
Prompt string
|
|
||||||
Orientation string
|
|
||||||
Frames int
|
|
||||||
Model string
|
|
||||||
Size string
|
|
||||||
MediaID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraImageTaskStatus 图片任务状态
|
|
||||||
type SoraImageTaskStatus struct {
|
|
||||||
ID string
|
|
||||||
Status string
|
|
||||||
ProgressPct float64
|
|
||||||
URLs []string
|
|
||||||
ErrorMsg string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraVideoTaskStatus 视频任务状态
|
|
||||||
type SoraVideoTaskStatus struct {
|
|
||||||
ID string
|
|
||||||
Status string
|
|
||||||
ProgressPct int
|
|
||||||
URLs []string
|
|
||||||
GenerationID string
|
|
||||||
ErrorMsg string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraCameoStatus 角色处理中间态
|
|
||||||
type SoraCameoStatus struct {
|
|
||||||
Status string
|
|
||||||
StatusMessage string
|
|
||||||
DisplayNameHint string
|
|
||||||
UsernameHint string
|
|
||||||
ProfileAssetURL string
|
|
||||||
InstructionSetHint any
|
|
||||||
InstructionSet any
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraCharacterFinalizeRequest 角色定稿请求参数
|
|
||||||
type SoraCharacterFinalizeRequest struct {
|
|
||||||
CameoID string
|
|
||||||
Username string
|
|
||||||
DisplayName string
|
|
||||||
ProfileAssetPointer string
|
|
||||||
InstructionSet any
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraUpstreamError 上游错误
|
|
||||||
type SoraUpstreamError struct {
|
|
||||||
StatusCode int
|
|
||||||
Message string
|
|
||||||
Headers http.Header
|
|
||||||
Body []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *SoraUpstreamError) Error() string {
|
|
||||||
if e == nil {
|
|
||||||
return "sora upstream error"
|
|
||||||
}
|
|
||||||
if e.Message != "" {
|
|
||||||
return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,564 +0,0 @@
|
|||||||
//go:build unit
|
|
||||||
|
|
||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ SoraClient = (*stubSoraClientForPoll)(nil)
|
|
||||||
|
|
||||||
type stubSoraClientForPoll struct {
|
|
||||||
imageStatus *SoraImageTaskStatus
|
|
||||||
videoStatus *SoraVideoTaskStatus
|
|
||||||
imageCalls int
|
|
||||||
videoCalls int
|
|
||||||
enhanced string
|
|
||||||
enhanceErr error
|
|
||||||
storyboard bool
|
|
||||||
videoReq SoraVideoRequest
|
|
||||||
parseErr error
|
|
||||||
postCalls int
|
|
||||||
deleteCalls int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
|
||||||
func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
|
|
||||||
return "task-image", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
|
||||||
s.videoReq = req
|
|
||||||
return "task-video", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
|
||||||
s.storyboard = true
|
|
||||||
return "task-video", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
|
||||||
return "cameo-1", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
|
||||||
return &SoraCameoStatus{
|
|
||||||
Status: "finalized",
|
|
||||||
StatusMessage: "Completed",
|
|
||||||
DisplayNameHint: "Character",
|
|
||||||
UsernameHint: "user.character",
|
|
||||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
|
||||||
return []byte("avatar"), nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
|
||||||
return "asset-pointer", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
|
||||||
return "character-1", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
|
||||||
s.postCalls++
|
|
||||||
return "s_post", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
|
|
||||||
s.deleteCalls++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
|
||||||
if s.parseErr != nil {
|
|
||||||
return "", s.parseErr
|
|
||||||
}
|
|
||||||
return "https://example.com/no-watermark.mp4", nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
|
||||||
if s.enhanced != "" {
|
|
||||||
return s.enhanced, s.enhanceErr
|
|
||||||
}
|
|
||||||
return "enhanced prompt", s.enhanceErr
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
|
||||||
s.imageCalls++
|
|
||||||
return s.imageStatus, nil
|
|
||||||
}
|
|
||||||
func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
|
||||||
s.videoCalls++
|
|
||||||
return s.videoStatus, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
imageStatus: &SoraImageTaskStatus{
|
|
||||||
Status: "completed",
|
|
||||||
URLs: []string{"https://example.com/a.png"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
|
|
||||||
urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, []string{"https://example.com/a.png"}, urls)
|
|
||||||
require.Equal(t, 1, client.imageCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
enhanced: "cinematic prompt",
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
account := &Account{
|
|
||||||
ID: 1,
|
|
||||||
Platform: PlatformSora,
|
|
||||||
Status: StatusActive,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"model_mapping": map[string]any{
|
|
||||||
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
|
||||||
|
|
||||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.Equal(t, "prompt", result.MediaType)
|
|
||||||
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
|
||||||
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
|
||||||
Status: "completed",
|
|
||||||
URLs: []string{"https://example.com/v.mp4"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
|
||||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
|
|
||||||
|
|
||||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.True(t, client.storyboard)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
|
||||||
Status: "completed",
|
|
||||||
URLs: []string{"https://example.com/v.mp4"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
|
||||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
|
|
||||||
|
|
||||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.Equal(t, 3, client.videoReq.VideoCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
|
||||||
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
|
|
||||||
|
|
||||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.Equal(t, "prompt", result.MediaType)
|
|
||||||
require.Equal(t, 0, client.videoCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
|
||||||
Status: "completed",
|
|
||||||
URLs: []string{"https://example.com/original.mp4"},
|
|
||||||
GenerationID: "gen_1",
|
|
||||||
},
|
|
||||||
parseErr: errors.New("parse failed"),
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
|
||||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
|
|
||||||
|
|
||||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
|
|
||||||
require.Equal(t, 1, client.postCalls)
|
|
||||||
require.Equal(t, 0, client.deleteCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
|
||||||
Status: "completed",
|
|
||||||
URLs: []string{"https://example.com/original.mp4"},
|
|
||||||
GenerationID: "gen_1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
|
||||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
|
|
||||||
|
|
||||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
|
|
||||||
require.Equal(t, 1, client.postCalls)
|
|
||||||
require.Equal(t, 1, client.deleteCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
|
||||||
client := &stubSoraClientForPoll{
|
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
|
||||||
Status: "failed",
|
|
||||||
ErrorMsg: "reject",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 1,
|
|
||||||
MaxPollAttempts: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
|
||||||
|
|
||||||
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, status)
|
|
||||||
require.Contains(t, err.Error(), "reject")
|
|
||||||
require.Equal(t, 1, client.videoCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Gateway: config.GatewayConfig{
|
|
||||||
SoraMediaSigningKey: "test-key",
|
|
||||||
SoraMediaSignedURLTTLSeconds: 600,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
service := NewSoraGatewayService(nil, nil, nil, cfg)
|
|
||||||
|
|
||||||
url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
|
|
||||||
require.Contains(t, url, "/sora/media-signed")
|
|
||||||
require.Contains(t, url, "expires=")
|
|
||||||
require.Contains(t, url, "sig=")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeSoraMediaURLs_Empty(t *testing.T) {
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
result := svc.normalizeSoraMediaURLs(nil)
|
|
||||||
require.Empty(t, result)
|
|
||||||
|
|
||||||
result = svc.normalizeSoraMediaURLs([]string{})
|
|
||||||
require.Empty(t, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) {
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"}
|
|
||||||
result := svc.normalizeSoraMediaURLs(urls)
|
|
||||||
require.Equal(t, urls, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) {
|
|
||||||
cfg := &config.Config{}
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
|
||||||
urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"}
|
|
||||||
result := svc.normalizeSoraMediaURLs(urls)
|
|
||||||
require.Len(t, result, 2)
|
|
||||||
require.Contains(t, result[0], "/sora/media")
|
|
||||||
require.Contains(t, result[1], "/sora/media")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) {
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"}
|
|
||||||
result := svc.normalizeSoraMediaURLs(urls)
|
|
||||||
require.Len(t, result, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildSoraContent_Image(t *testing.T) {
|
|
||||||
content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"})
|
|
||||||
require.Contains(t, content, "")
|
|
||||||
require.Contains(t, content, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildSoraContent_Video(t *testing.T) {
|
|
||||||
content := buildSoraContent("video", []string{"https://a.com/v.mp4"})
|
|
||||||
require.Contains(t, content, "<video src='https://a.com/v.mp4'")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildSoraContent_VideoEmpty(t *testing.T) {
|
|
||||||
content := buildSoraContent("video", nil)
|
|
||||||
require.Empty(t, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildSoraContent_Prompt(t *testing.T) {
|
|
||||||
content := buildSoraContent("prompt", nil)
|
|
||||||
require.Empty(t, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraImageSizeFromModel(t *testing.T) {
|
|
||||||
require.Equal(t, "360", soraImageSizeFromModel("gpt-image"))
|
|
||||||
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-landscape"))
|
|
||||||
require.Equal(t, "540", soraImageSizeFromModel("gpt-image-portrait"))
|
|
||||||
require.Equal(t, "540", soraImageSizeFromModel("something-landscape"))
|
|
||||||
require.Equal(t, "360", soraImageSizeFromModel("unknown-model"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFirstMediaURL(t *testing.T) {
|
|
||||||
require.Equal(t, "", firstMediaURL(nil))
|
|
||||||
require.Equal(t, "", firstMediaURL([]string{}))
|
|
||||||
require.Equal(t, "a", firstMediaURL([]string{"a", "b"}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraProErrorMessage(t *testing.T) {
|
|
||||||
require.Contains(t, soraProErrorMessage("sora2pro-hd", ""), "Pro-HD")
|
|
||||||
require.Contains(t, soraProErrorMessage("sora2pro", ""), "Pro")
|
|
||||||
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(rec)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
|
|
||||||
|
|
||||||
body := rec.Body.String()
|
|
||||||
require.Contains(t, body, "event: error\n")
|
|
||||||
require.Contains(t, body, "data: [DONE]\n\n")
|
|
||||||
|
|
||||||
lines := strings.Split(body, "\n")
|
|
||||||
require.GreaterOrEqual(t, len(lines), 2)
|
|
||||||
require.Equal(t, "event: error", lines[0])
|
|
||||||
require.True(t, strings.HasPrefix(lines[1], "data: "))
|
|
||||||
|
|
||||||
data := strings.TrimPrefix(lines[1], "data: ")
|
|
||||||
var parsed map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
|
|
||||||
errObj, ok := parsed["error"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "upstream_error", errObj["type"])
|
|
||||||
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
sourceHeaders := http.Header{}
|
|
||||||
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
|
||||||
|
|
||||||
err := svc.handleSoraRequestError(
|
|
||||||
context.Background(),
|
|
||||||
&Account{ID: 1, Platform: PlatformSora},
|
|
||||||
&SoraUpstreamError{
|
|
||||||
StatusCode: http.StatusForbidden,
|
|
||||||
Message: "forbidden",
|
|
||||||
Headers: sourceHeaders,
|
|
||||||
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
|
|
||||||
},
|
|
||||||
"sora2-landscape-10s",
|
|
||||||
nil,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
|
|
||||||
var failoverErr *UpstreamFailoverError
|
|
||||||
require.ErrorAs(t, err, &failoverErr)
|
|
||||||
require.NotNil(t, failoverErr.ResponseHeaders)
|
|
||||||
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
|
||||||
|
|
||||||
sourceHeaders.Set("cf-ray", "mutated-after-return")
|
|
||||||
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(404))
|
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(429))
|
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(500))
|
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(502))
|
|
||||||
require.False(t, svc.shouldFailoverUpstreamError(200))
|
|
||||||
require.False(t, svc.shouldFailoverUpstreamError(400))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithSoraTimeout_NilService(t *testing.T) {
|
|
||||||
var svc *SoraGatewayService
|
|
||||||
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
|
|
||||||
require.NotNil(t, ctx)
|
|
||||||
require.Nil(t, cancel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithSoraTimeout_ZeroTimeout(t *testing.T) {
|
|
||||||
cfg := &config.Config{}
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
|
||||||
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
|
|
||||||
require.NotNil(t, ctx)
|
|
||||||
require.Nil(t, cancel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithSoraTimeout_PositiveTimeout(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Gateway: config.GatewayConfig{
|
|
||||||
SoraRequestTimeoutSeconds: 30,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
|
||||||
ctx, cancel := svc.withSoraTimeout(context.Background(), false)
|
|
||||||
require.NotNil(t, ctx)
|
|
||||||
require.NotNil(t, cancel)
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPollInterval(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
PollIntervalSeconds: 5,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
|
||||||
require.Equal(t, 5*time.Second, svc.pollInterval())
|
|
||||||
|
|
||||||
// 默认值
|
|
||||||
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
require.True(t, svc2.pollInterval() > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPollMaxAttempts(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Sora: config.SoraConfig{
|
|
||||||
Client: config.SoraClientConfig{
|
|
||||||
MaxPollAttempts: 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, cfg)
|
|
||||||
require.Equal(t, 100, svc.pollMaxAttempts())
|
|
||||||
|
|
||||||
// 默认值
|
|
||||||
svc2 := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
|
||||||
require.True(t, svc2.pollMaxAttempts() > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
|
|
||||||
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
|
||||||
encoded := "data:image/png;base64,aGVsbG8="
|
|
||||||
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, data)
|
|
||||||
require.Contains(t, filename, ".png")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
|
|
||||||
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
|
|
||||||
body := map[string]any{
|
|
||||||
"watermark_free": float64(1),
|
|
||||||
"watermark_fallback_on_failure": float64(0),
|
|
||||||
}
|
|
||||||
opts := parseSoraWatermarkOptions(body)
|
|
||||||
require.True(t, opts.Enabled)
|
|
||||||
require.False(t, opts.FallbackOnFailure)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseSoraVideoCount(t *testing.T) {
|
|
||||||
require.Equal(t, 1, parseSoraVideoCount(nil))
|
|
||||||
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
|
|
||||||
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
|
|
||||||
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
|
|
||||||
}
|
|
||||||
@ -1,532 +0,0 @@
|
|||||||
//nolint:unused
|
|
||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|
||||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
|
||||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
|
||||||
|
|
||||||
const soraRewriteBufferLimit = 2048
|
|
||||||
|
|
||||||
type soraStreamingResult struct {
|
|
||||||
mediaType string
|
|
||||||
mediaURLs []string
|
|
||||||
imageCount int
|
|
||||||
imageSize string
|
|
||||||
firstTokenMs *int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
|
|
||||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
|
||||||
setOpsUpstreamError(c, 0, safeErr, "")
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: 0,
|
|
||||||
Kind: "request_error",
|
|
||||||
Message: safeErr,
|
|
||||||
})
|
|
||||||
if c != nil {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream request failed",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
|
||||||
if s.rateLimitService == nil || account == nil || resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|
||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
||||||
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
|
|
||||||
upstreamMsg = msg
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamDetail := ""
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = 2048
|
|
||||||
}
|
|
||||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
|
||||||
}
|
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|
||||||
Kind: "http_error",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
Detail: upstreamDetail,
|
|
||||||
})
|
|
||||||
|
|
||||||
if c != nil {
|
|
||||||
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
|
|
||||||
c.JSON(resp.StatusCode, responsePayload)
|
|
||||||
}
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
|
|
||||||
if len(respBody) > 0 {
|
|
||||||
var payload map[string]any
|
|
||||||
if err := json.Unmarshal(respBody, &payload); err == nil {
|
|
||||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
|
||||||
if overrideMessage != "" {
|
|
||||||
errObj["message"] = overrideMessage
|
|
||||||
}
|
|
||||||
payload["error"] = errObj
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return map[string]any{
|
|
||||||
"error": map[string]any{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": overrideMessage,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
|
|
||||||
if resp == nil {
|
|
||||||
return nil, errors.New("empty response")
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientStream {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Accel-Buffering", "no")
|
|
||||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
|
||||||
c.Header("x-request-id", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
w := c.Writer
|
|
||||||
flusher, _ := w.(http.Flusher)
|
|
||||||
|
|
||||||
contentBuilder := strings.Builder{}
|
|
||||||
var firstTokenMs *int
|
|
||||||
var upstreamError error
|
|
||||||
rewriteBuffer := ""
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
maxLineSize := defaultMaxLineSize
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
|
||||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
||||||
}
|
|
||||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
|
||||||
|
|
||||||
sendLine := func(line string) error {
|
|
||||||
if !clientStream {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if soraSSEDataRe.MatchString(line) {
|
|
||||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
if data == "[DONE]" {
|
|
||||||
if rewriteBuffer != "" {
|
|
||||||
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if flushLine != "" {
|
|
||||||
if flushContent != "" {
|
|
||||||
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := sendLine(flushLine); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rewriteBuffer = ""
|
|
||||||
}
|
|
||||||
if err := sendLine("data: [DONE]"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
|
||||||
if errEvent != nil && upstreamError == nil {
|
|
||||||
upstreamError = errEvent
|
|
||||||
}
|
|
||||||
if contentDelta != "" {
|
|
||||||
if firstTokenMs == nil {
|
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
|
||||||
firstTokenMs = &ms
|
|
||||||
}
|
|
||||||
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := sendLine(updatedLine); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := sendLine(line); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if errors.Is(err, bufio.ErrTooLong) {
|
|
||||||
if clientStream {
|
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
|
|
||||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
|
||||||
}
|
|
||||||
if clientStream {
|
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
content := contentBuilder.String()
|
|
||||||
mediaType, mediaURLs := s.extractSoraMedia(content)
|
|
||||||
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
|
|
||||||
mediaType = "prompt"
|
|
||||||
}
|
|
||||||
imageSize := ""
|
|
||||||
imageCount := 0
|
|
||||||
if mediaType == "image" {
|
|
||||||
imageSize = soraImageSizeFromModel(originalModel)
|
|
||||||
imageCount = len(mediaURLs)
|
|
||||||
}
|
|
||||||
|
|
||||||
if upstreamError != nil && !clientStream {
|
|
||||||
if c != nil {
|
|
||||||
c.JSON(http.StatusBadGateway, map[string]any{
|
|
||||||
"error": map[string]any{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": upstreamError.Error(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, upstreamError
|
|
||||||
}
|
|
||||||
|
|
||||||
if !clientStream {
|
|
||||||
response := buildSoraNonStreamResponse(content, originalModel)
|
|
||||||
if len(mediaURLs) > 0 {
|
|
||||||
response["media_url"] = mediaURLs[0]
|
|
||||||
if len(mediaURLs) > 1 {
|
|
||||||
response["media_urls"] = mediaURLs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &soraStreamingResult{
|
|
||||||
mediaType: mediaType,
|
|
||||||
mediaURLs: mediaURLs,
|
|
||||||
imageCount: imageCount,
|
|
||||||
imageSize: imageSize,
|
|
||||||
firstTokenMs: firstTokenMs,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
|
||||||
if strings.TrimSpace(data) == "" {
|
|
||||||
return "data: ", "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var payload map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
|
||||||
return "data: " + data, "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
|
||||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
|
||||||
return "data: " + data, "", errors.New(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
|
|
||||||
payload["model"] = originalModel
|
|
||||||
}
|
|
||||||
|
|
||||||
contentDelta, updated := extractSoraContent(payload)
|
|
||||||
if updated {
|
|
||||||
var rewritten string
|
|
||||||
if rewriteBuffer != nil {
|
|
||||||
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
|
||||||
} else {
|
|
||||||
rewritten = s.rewriteSoraContent(contentDelta)
|
|
||||||
}
|
|
||||||
if rewritten != contentDelta {
|
|
||||||
applySoraContent(payload, rewritten)
|
|
||||||
contentDelta = rewritten
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedData, err := jsonMarshalRaw(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "data: " + data, contentDelta, nil
|
|
||||||
}
|
|
||||||
return "data: " + string(updatedData), contentDelta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractSoraContent(payload map[string]any) (string, bool) {
|
|
||||||
choices, ok := payload["choices"].([]any)
|
|
||||||
if !ok || len(choices) == 0 {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
choice, ok := choices[0].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
|
||||||
if content, ok := delta["content"].(string); ok {
|
|
||||||
return content, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if message, ok := choice["message"].(map[string]any); ok {
|
|
||||||
if content, ok := message["content"].(string); ok {
|
|
||||||
return content, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func applySoraContent(payload map[string]any, content string) {
|
|
||||||
choices, ok := payload["choices"].([]any)
|
|
||||||
if !ok || len(choices) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
choice, ok := choices[0].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
|
||||||
delta["content"] = content
|
|
||||||
choice["delta"] = delta
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if message, ok := choice["message"].(map[string]any); ok {
|
|
||||||
message["content"] = content
|
|
||||||
choice["message"] = message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
|
||||||
if buffer == nil {
|
|
||||||
return s.rewriteSoraContent(contentDelta)
|
|
||||||
}
|
|
||||||
if contentDelta == "" && *buffer == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
combined := *buffer + contentDelta
|
|
||||||
rewritten := s.rewriteSoraContent(combined)
|
|
||||||
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
|
||||||
if bufferStart < 0 {
|
|
||||||
*buffer = ""
|
|
||||||
return rewritten
|
|
||||||
}
|
|
||||||
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
|
||||||
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
|
||||||
}
|
|
||||||
output := rewritten[:bufferStart]
|
|
||||||
*buffer = rewritten[bufferStart:]
|
|
||||||
return output
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
|
||||||
minIndex := -1
|
|
||||||
start := 0
|
|
||||||
for {
|
|
||||||
idx := strings.Index(content[start:], "![")
|
|
||||||
if idx < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
idx += start
|
|
||||||
if !hasSoraImageMatchAt(content, idx) {
|
|
||||||
if minIndex == -1 || idx < minIndex {
|
|
||||||
minIndex = idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
start = idx + 2
|
|
||||||
}
|
|
||||||
lower := strings.ToLower(content)
|
|
||||||
start = 0
|
|
||||||
for {
|
|
||||||
idx := strings.Index(lower[start:], "<video")
|
|
||||||
if idx < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
idx += start
|
|
||||||
if !hasSoraVideoMatchAt(content, idx) {
|
|
||||||
if minIndex == -1 || idx < minIndex {
|
|
||||||
minIndex = idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
start = idx + len("<video")
|
|
||||||
}
|
|
||||||
return minIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasSoraImageMatchAt(content string, idx int) bool {
|
|
||||||
if idx < 0 || idx >= len(content) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
|
||||||
return loc != nil && loc[0] == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasSoraVideoMatchAt(content string, idx int) bool {
|
|
||||||
if idx < 0 || idx >= len(content) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
|
||||||
return loc != nil && loc[0] == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
|
||||||
if content == "" {
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
|
|
||||||
sub := soraImageMarkdownRe.FindStringSubmatch(match)
|
|
||||||
if len(sub) < 2 {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
rewritten := s.rewriteSoraURL(sub[1])
|
|
||||||
if rewritten == sub[1] {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
return strings.Replace(match, sub[1], rewritten, 1)
|
|
||||||
})
|
|
||||||
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
|
|
||||||
sub := soraVideoHTMLRe.FindStringSubmatch(match)
|
|
||||||
if len(sub) < 2 {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
rewritten := s.rewriteSoraURL(sub[1])
|
|
||||||
if rewritten == sub[1] {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
return strings.Replace(match, sub[1], rewritten, 1)
|
|
||||||
})
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
|
||||||
if buffer == "" {
|
|
||||||
return "", "", nil
|
|
||||||
}
|
|
||||||
rewritten := s.rewriteSoraContent(buffer)
|
|
||||||
payload := map[string]any{
|
|
||||||
"choices": []any{
|
|
||||||
map[string]any{
|
|
||||||
"delta": map[string]any{
|
|
||||||
"content": rewritten,
|
|
||||||
},
|
|
||||||
"index": 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if originalModel != "" {
|
|
||||||
payload["model"] = originalModel
|
|
||||||
}
|
|
||||||
updatedData, err := jsonMarshalRaw(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return "data: " + string(updatedData), rewritten, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
parsed, err := url.Parse(raw)
|
|
||||||
if err != nil {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
path := parsed.Path
|
|
||||||
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
return s.buildSoraMediaURL(path, parsed.RawQuery)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
|
|
||||||
if content == "" {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
|
|
||||||
return "video", []string{match[1]}
|
|
||||||
}
|
|
||||||
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
|
|
||||||
if len(imageMatches) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
urls := make([]string, 0, len(imageMatches))
|
|
||||||
for _, match := range imageMatches {
|
|
||||||
if len(match) > 1 {
|
|
||||||
urls = append(urls, match[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "image", urls
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSoraPromptEnhanceModel(model string) bool {
|
|
||||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
|
|
||||||
}
|
|
||||||
@ -1,63 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SoraGeneration 代表一条 Sora 客户端生成记录。
|
|
||||||
type SoraGeneration struct {
|
|
||||||
ID int64 `json:"id"`
|
|
||||||
UserID int64 `json:"user_id"`
|
|
||||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
MediaType string `json:"media_type"` // video / image
|
|
||||||
Status string `json:"status"` // pending / generating / completed / failed / cancelled
|
|
||||||
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
|
|
||||||
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
|
|
||||||
FileSizeBytes int64 `json:"file_size_bytes"`
|
|
||||||
StorageType string `json:"storage_type"` // s3 / local / upstream / none
|
|
||||||
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
|
|
||||||
UpstreamTaskID string `json:"upstream_task_id"`
|
|
||||||
ErrorMessage string `json:"error_message"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sora 生成记录状态常量
|
|
||||||
const (
|
|
||||||
SoraGenStatusPending = "pending"
|
|
||||||
SoraGenStatusGenerating = "generating"
|
|
||||||
SoraGenStatusCompleted = "completed"
|
|
||||||
SoraGenStatusFailed = "failed"
|
|
||||||
SoraGenStatusCancelled = "cancelled"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Sora 存储类型常量
|
|
||||||
const (
|
|
||||||
SoraStorageTypeS3 = "s3"
|
|
||||||
SoraStorageTypeLocal = "local"
|
|
||||||
SoraStorageTypeUpstream = "upstream"
|
|
||||||
SoraStorageTypeNone = "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SoraGenerationListParams 查询生成记录的参数。
|
|
||||||
type SoraGenerationListParams struct {
|
|
||||||
UserID int64
|
|
||||||
Status string // 可选筛选
|
|
||||||
StorageType string // 可选筛选
|
|
||||||
MediaType string // 可选筛选
|
|
||||||
Page int
|
|
||||||
PageSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraGenerationRepository 生成记录持久化接口。
|
|
||||||
type SoraGenerationRepository interface {
|
|
||||||
Create(ctx context.Context, gen *SoraGeneration) error
|
|
||||||
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
|
|
||||||
Update(ctx context.Context, gen *SoraGeneration) error
|
|
||||||
Delete(ctx context.Context, id int64) error
|
|
||||||
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
|
|
||||||
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
|
|
||||||
}
|
|
||||||
@ -1,332 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
|
|
||||||
ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
|
|
||||||
// ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
|
|
||||||
ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
|
|
||||||
// ErrSoraGenerationNotActive 表示任务不在可取消状态。
|
|
||||||
ErrSoraGenerationNotActive = errors.New("sora generation is not active")
|
|
||||||
)
|
|
||||||
|
|
||||||
const soraGenerationActiveLimit = 3
|
|
||||||
|
|
||||||
type soraGenerationRepoAtomicCreator interface {
|
|
||||||
CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type soraGenerationRepoConditionalUpdater interface {
|
|
||||||
UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
|
|
||||||
UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
|
|
||||||
UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
|
|
||||||
UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
|
|
||||||
UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
|
|
||||||
type SoraGenerationService struct {
|
|
||||||
genRepo SoraGenerationRepository
|
|
||||||
s3Storage *SoraS3Storage
|
|
||||||
quotaService *SoraQuotaService
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSoraGenerationService 创建生成记录服务。
|
|
||||||
func NewSoraGenerationService(
|
|
||||||
genRepo SoraGenerationRepository,
|
|
||||||
s3Storage *SoraS3Storage,
|
|
||||||
quotaService *SoraQuotaService,
|
|
||||||
) *SoraGenerationService {
|
|
||||||
return &SoraGenerationService{
|
|
||||||
genRepo: genRepo,
|
|
||||||
s3Storage: s3Storage,
|
|
||||||
quotaService: quotaService,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePending 创建一条 pending 状态的生成记录。
|
|
||||||
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
|
|
||||||
gen := &SoraGeneration{
|
|
||||||
UserID: userID,
|
|
||||||
APIKeyID: apiKeyID,
|
|
||||||
Model: model,
|
|
||||||
Prompt: prompt,
|
|
||||||
MediaType: mediaType,
|
|
||||||
Status: SoraGenStatusPending,
|
|
||||||
StorageType: SoraStorageTypeNone,
|
|
||||||
}
|
|
||||||
if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
|
|
||||||
if err := atomicCreator.CreatePendingWithLimit(
|
|
||||||
ctx,
|
|
||||||
gen,
|
|
||||||
[]string{SoraGenStatusPending, SoraGenStatusGenerating},
|
|
||||||
soraGenerationActiveLimit,
|
|
||||||
); err != nil {
|
|
||||||
if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("create generation: %w", err)
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
|
|
||||||
return gen, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.genRepo.Create(ctx, gen); err != nil {
|
|
||||||
return nil, fmt.Errorf("create generation: %w", err)
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
|
|
||||||
return gen, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkGenerating 标记为生成中。
|
|
||||||
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
|
|
||||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
|
||||||
updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !updated {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if gen.Status != SoraGenStatusPending {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
gen.Status = SoraGenStatusGenerating
|
|
||||||
gen.UpstreamTaskID = upstreamTaskID
|
|
||||||
return s.genRepo.Update(ctx, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkCompleted 标记为已完成。
|
|
||||||
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
|
|
||||||
now := time.Now()
|
|
||||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
|
||||||
updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !updated {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
gen.Status = SoraGenStatusCompleted
|
|
||||||
gen.MediaURL = mediaURL
|
|
||||||
gen.MediaURLs = mediaURLs
|
|
||||||
gen.StorageType = storageType
|
|
||||||
gen.S3ObjectKeys = s3Keys
|
|
||||||
gen.FileSizeBytes = fileSizeBytes
|
|
||||||
gen.CompletedAt = &now
|
|
||||||
return s.genRepo.Update(ctx, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkFailed 标记为失败。
|
|
||||||
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
|
|
||||||
now := time.Now()
|
|
||||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
|
||||||
updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !updated {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
gen.Status = SoraGenStatusFailed
|
|
||||||
gen.ErrorMessage = errMsg
|
|
||||||
gen.CompletedAt = &now
|
|
||||||
return s.genRepo.Update(ctx, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkCancelled 标记为已取消。
|
|
||||||
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
|
|
||||||
now := time.Now()
|
|
||||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
|
||||||
updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !updated {
|
|
||||||
return ErrSoraGenerationNotActive
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
|
||||||
return ErrSoraGenerationNotActive
|
|
||||||
}
|
|
||||||
gen.Status = SoraGenStatusCancelled
|
|
||||||
gen.CompletedAt = &now
|
|
||||||
return s.genRepo.Update(ctx, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
|
|
||||||
func (s *SoraGenerationService) UpdateStorageForCompleted(
|
|
||||||
ctx context.Context,
|
|
||||||
id int64,
|
|
||||||
mediaURL string,
|
|
||||||
mediaURLs []string,
|
|
||||||
storageType string,
|
|
||||||
s3Keys []string,
|
|
||||||
fileSizeBytes int64,
|
|
||||||
) error {
|
|
||||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
|
||||||
updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !updated {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if gen.Status != SoraGenStatusCompleted {
|
|
||||||
return ErrSoraGenerationStateConflict
|
|
||||||
}
|
|
||||||
gen.MediaURL = mediaURL
|
|
||||||
gen.MediaURLs = mediaURLs
|
|
||||||
gen.StorageType = storageType
|
|
||||||
gen.S3ObjectKeys = s3Keys
|
|
||||||
gen.FileSizeBytes = fileSizeBytes
|
|
||||||
return s.genRepo.Update(ctx, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByID 获取记录详情(含权限校验)。
|
|
||||||
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if gen.UserID != userID {
|
|
||||||
return nil, fmt.Errorf("无权访问此生成记录")
|
|
||||||
}
|
|
||||||
return gen, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// List 查询生成记录列表(分页 + 筛选)。
|
|
||||||
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
|
||||||
if params.Page <= 0 {
|
|
||||||
params.Page = 1
|
|
||||||
}
|
|
||||||
if params.PageSize <= 0 {
|
|
||||||
params.PageSize = 20
|
|
||||||
}
|
|
||||||
if params.PageSize > 100 {
|
|
||||||
params.PageSize = 100
|
|
||||||
}
|
|
||||||
return s.genRepo.List(ctx, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
|
|
||||||
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
|
|
||||||
gen, err := s.genRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if gen.UserID != userID {
|
|
||||||
return fmt.Errorf("无权删除此生成记录")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理 S3 文件
|
|
||||||
if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
|
|
||||||
if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
|
|
||||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 释放配额(S3/本地均释放)
|
|
||||||
if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
|
|
||||||
if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
|
|
||||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.genRepo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
|
|
||||||
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
|
|
||||||
return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
|
|
||||||
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
|
|
||||||
if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if len(gen.S3ObjectKeys) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
urls := make([]string, len(gen.S3ObjectKeys))
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var firstErr error
|
|
||||||
var errMu sync.Mutex
|
|
||||||
|
|
||||||
for idx, key := range gen.S3ObjectKeys {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(i int, objectKey string) {
|
|
||||||
defer wg.Done()
|
|
||||||
url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
|
|
||||||
if err != nil {
|
|
||||||
errMu.Lock()
|
|
||||||
if firstErr == nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
errMu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
urls[i] = url
|
|
||||||
}(idx, key)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if firstErr != nil {
|
|
||||||
return firstErr
|
|
||||||
}
|
|
||||||
|
|
||||||
gen.MediaURL = urls[0]
|
|
||||||
gen.MediaURLs = urls
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,881 +0,0 @@
|
|||||||
//go:build unit
|
|
||||||
|
|
||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
||||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ==================== Stub: SoraGenerationRepository ====================
|
|
||||||
|
|
||||||
var _ SoraGenerationRepository = (*stubGenRepo)(nil)
|
|
||||||
|
|
||||||
type stubGenRepo struct {
|
|
||||||
gens map[int64]*SoraGeneration
|
|
||||||
nextID int64
|
|
||||||
createErr error
|
|
||||||
getErr error
|
|
||||||
updateErr error
|
|
||||||
deleteErr error
|
|
||||||
listErr error
|
|
||||||
countErr error
|
|
||||||
countValue int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStubGenRepo() *stubGenRepo {
|
|
||||||
return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
|
|
||||||
if r.createErr != nil {
|
|
||||||
return r.createErr
|
|
||||||
}
|
|
||||||
gen.ID = r.nextID
|
|
||||||
gen.CreatedAt = time.Now()
|
|
||||||
r.nextID++
|
|
||||||
r.gens[gen.ID] = gen
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
|
|
||||||
if r.getErr != nil {
|
|
||||||
return nil, r.getErr
|
|
||||||
}
|
|
||||||
if gen, ok := r.gens[id]; ok {
|
|
||||||
return gen, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
|
|
||||||
if r.updateErr != nil {
|
|
||||||
return r.updateErr
|
|
||||||
}
|
|
||||||
r.gens[gen.ID] = gen
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
|
|
||||||
if r.deleteErr != nil {
|
|
||||||
return r.deleteErr
|
|
||||||
}
|
|
||||||
delete(r.gens, id)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
|
||||||
if r.listErr != nil {
|
|
||||||
return nil, 0, r.listErr
|
|
||||||
}
|
|
||||||
var result []*SoraGeneration
|
|
||||||
for _, gen := range r.gens {
|
|
||||||
if gen.UserID != params.UserID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if params.Status != "" && gen.Status != params.Status {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if params.StorageType != "" && gen.StorageType != params.StorageType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if params.MediaType != "" && gen.MediaType != params.MediaType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result = append(result, gen)
|
|
||||||
}
|
|
||||||
return result, int64(len(result)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
|
|
||||||
if r.countErr != nil {
|
|
||||||
return 0, r.countErr
|
|
||||||
}
|
|
||||||
if r.countValue > 0 {
|
|
||||||
return r.countValue, nil
|
|
||||||
}
|
|
||||||
var count int64
|
|
||||||
statusSet := make(map[string]struct{})
|
|
||||||
for _, s := range statuses {
|
|
||||||
statusSet[s] = struct{}{}
|
|
||||||
}
|
|
||||||
for _, gen := range r.gens {
|
|
||||||
if gen.UserID == userID {
|
|
||||||
if _, ok := statusSet[gen.Status]; ok {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
|
|
||||||
|
|
||||||
var _ UserRepository = (*stubUserRepoForQuota)(nil)
|
|
||||||
|
|
||||||
type stubUserRepoForQuota struct {
|
|
||||||
users map[int64]*User
|
|
||||||
updateErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStubUserRepoForQuota() *stubUserRepoForQuota {
|
|
||||||
return &stubUserRepoForQuota{users: make(map[int64]*User)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
|
|
||||||
if u, ok := r.users[id]; ok {
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("user not found")
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
|
|
||||||
if r.updateErr != nil {
|
|
||||||
return r.updateErr
|
|
||||||
}
|
|
||||||
r.users[user.ID] = user
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
|
|
||||||
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
|
|
||||||
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
|
|
||||||
|
|
||||||
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
|
|
||||||
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
|
|
||||||
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
|
|
||||||
storage := &SoraS3Storage{}
|
|
||||||
storage.cfg = &SoraS3Settings{
|
|
||||||
Enabled: true,
|
|
||||||
Bucket: "test-bucket",
|
|
||||||
CDNURL: cdnURL,
|
|
||||||
}
|
|
||||||
// 需要 non-nil client 使 getClient 命中缓存
|
|
||||||
storage.client = s3.New(s3.Options{})
|
|
||||||
return storage
|
|
||||||
}
|
|
||||||
|
|
||||||
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
|
|
||||||
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
|
|
||||||
func newS3StorageFailingDelete() *SoraS3Storage {
|
|
||||||
return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== CreatePending ====================
|
|
||||||
|
|
||||||
func TestCreatePending_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), gen.ID)
|
|
||||||
require.Equal(t, int64(1), gen.UserID)
|
|
||||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
|
||||||
require.Equal(t, "一只猫跳舞", gen.Prompt)
|
|
||||||
require.Equal(t, "video", gen.MediaType)
|
|
||||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
|
||||||
require.Equal(t, SoraStorageTypeNone, gen.StorageType)
|
|
||||||
require.Nil(t, gen.APIKeyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreatePending_WithAPIKeyID(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
apiKeyID := int64(42)
|
|
||||||
gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, gen.APIKeyID)
|
|
||||||
require.Equal(t, int64(42), *gen.APIKeyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreatePending_RepoError(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.createErr = fmt.Errorf("db write error")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, gen)
|
|
||||||
require.Contains(t, err.Error(), "create generation")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== MarkGenerating ====================
|
|
||||||
|
|
||||||
func TestMarkGenerating_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
|
|
||||||
require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkGenerating_NotFound(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkGenerating(context.Background(), 999, "")
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkGenerating_UpdateError(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
|
||||||
repo.updateErr = fmt.Errorf("update failed")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkGenerating(context.Background(), 1, "")
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== MarkCompleted ====================
|
|
||||||
|
|
||||||
func TestMarkCompleted_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCompleted(context.Background(), 1,
|
|
||||||
"https://cdn.example.com/video.mp4",
|
|
||||||
[]string{"https://cdn.example.com/video.mp4"},
|
|
||||||
SoraStorageTypeS3,
|
|
||||||
[]string{"sora/1/2024/01/01/uuid.mp4"},
|
|
||||||
1048576,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
gen := repo.gens[1]
|
|
||||||
require.Equal(t, SoraGenStatusCompleted, gen.Status)
|
|
||||||
require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
|
|
||||||
require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
|
|
||||||
require.Equal(t, SoraStorageTypeS3, gen.StorageType)
|
|
||||||
require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
|
|
||||||
require.Equal(t, int64(1048576), gen.FileSizeBytes)
|
|
||||||
require.NotNil(t, gen.CompletedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCompleted_NotFound(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCompleted_UpdateError(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
|
||||||
repo.updateErr = fmt.Errorf("update failed")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== MarkFailed ====================
|
|
||||||
|
|
||||||
func TestMarkFailed_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
|
|
||||||
require.NoError(t, err)
|
|
||||||
gen := repo.gens[1]
|
|
||||||
require.Equal(t, SoraGenStatusFailed, gen.Status)
|
|
||||||
require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
|
|
||||||
require.NotNil(t, gen.CompletedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkFailed_NotFound(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkFailed(context.Background(), 999, "error")
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkFailed_UpdateError(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
|
||||||
repo.updateErr = fmt.Errorf("update failed")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkFailed(context.Background(), 1, "err")
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== MarkCancelled ====================
|
|
||||||
|
|
||||||
func TestMarkCancelled_Pending(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
|
||||||
require.NotNil(t, repo.gens[1].CompletedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCancelled_Generating(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCancelled_Completed(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.ErrorIs(t, err, ErrSoraGenerationNotActive)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCancelled_Failed(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCancelled_NotFound(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 999)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkCancelled_UpdateError(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
|
||||||
repo.updateErr = fmt.Errorf("update failed")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.MarkCancelled(context.Background(), 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== GetByID ====================
|
|
||||||
|
|
||||||
func TestGetByID_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), gen.ID)
|
|
||||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetByID_WrongUser(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, gen)
|
|
||||||
require.Contains(t, err.Error(), "无权访问")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetByID_NotFound(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, err := svc.GetByID(context.Background(), 999, 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== List ====================
|
|
||||||
|
|
||||||
func TestList_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
|
|
||||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
|
|
||||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, gens, 2) // 只有 userID=1 的
|
|
||||||
require.Equal(t, int64(2), total)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestList_DefaultPagination(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
|
|
||||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestList_MaxPageSize(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
// pageSize > 100 → 应限制为 100
|
|
||||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestList_Error(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.listErr = fmt.Errorf("db error")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Delete ====================
|
|
||||||
|
|
||||||
func TestDelete_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, exists := repo.gens[1]
|
|
||||||
require.False(t, exists)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_WrongUser(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "无权删除")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_NotFound(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 999, 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err) // s3Storage 为 nil,跳过清理
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err) // quotaService 为 nil,跳过释放
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_NonS3NoCleanup(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_DeleteRepoError(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
|
|
||||||
repo.deleteErr = fmt.Errorf("delete failed")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== CountActiveByUser ====================
|
|
||||||
|
|
||||||
func TestCountActiveByUser_Success(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
|
||||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
|
|
||||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
count, err := svc.CountActiveByUser(context.Background(), 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(2), count)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCountActiveByUser_NoActive(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
count, err := svc.CountActiveByUser(context.Background(), 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(0), count)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCountActiveByUser_Error(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.countErr = fmt.Errorf("db error")
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
_, err := svc.CountActiveByUser(context.Background(), 1)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== ResolveMediaURLs ====================
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_NilGen(t *testing.T) {
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
|
||||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_NonS3(t *testing.T) {
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
|
||||||
gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
|
|
||||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
|
||||||
require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
|
||||||
gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
|
|
||||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_Local(t *testing.T) {
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
|
||||||
gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
|
|
||||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
|
||||||
require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 状态流转完整测试 ====================
|
|
||||||
|
|
||||||
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
// 1. 创建 pending
|
|
||||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
|
||||||
|
|
||||||
// 2. 标记 generating
|
|
||||||
err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)
|
|
||||||
|
|
||||||
// 3. 标记 completed
|
|
||||||
err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
|
|
||||||
|
|
||||||
err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
|
|
||||||
require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
err := svc.MarkCancelled(context.Background(), gen.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
|
|
||||||
err := svc.MarkCancelled(context.Background(), gen.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 权限隔离测试 ====================
|
|
||||||
|
|
||||||
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
|
|
||||||
// 用户 2 尝试访问用户 1 的记录
|
|
||||||
_, err := svc.GetByID(context.Background(), gen.ID, 2)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "无权访问")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
svc := NewSoraGenerationService(repo, nil, nil)
|
|
||||||
|
|
||||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), gen.ID, 2)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "无权删除")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Delete: S3 清理 + 配额释放路径 ====================
|
|
||||||
|
|
||||||
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
|
|
||||||
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
|
|
||||||
// 验证 Delete 仍然成功(S3 错误只是记录日志)
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{
|
|
||||||
ID: 1, UserID: 1,
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
|
|
||||||
}
|
|
||||||
s3Storage := newS3StorageFailingDelete()
|
|
||||||
svc := NewSoraGenerationService(repo, s3Storage, nil)
|
|
||||||
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err) // S3 清理失败不影响删除
|
|
||||||
_, exists := repo.gens[1]
|
|
||||||
require.False(t, exists)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
|
|
||||||
// 有配额服务时,删除 S3 类型记录会释放配额
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{
|
|
||||||
ID: 1, UserID: 1,
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
FileSizeBytes: 1048576, // 1MB
|
|
||||||
}
|
|
||||||
|
|
||||||
userRepo := newStubUserRepoForQuota()
|
|
||||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
|
|
||||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
|
||||||
|
|
||||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
// 配额应被释放: 2MB - 1MB = 1MB
|
|
||||||
require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
|
|
||||||
// S3 清理 + 配额释放同时触发
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{
|
|
||||||
ID: 1, UserID: 1,
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{"key1"},
|
|
||||||
FileSizeBytes: 512,
|
|
||||||
}
|
|
||||||
|
|
||||||
userRepo := newStubUserRepoForQuota()
|
|
||||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
|
||||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
|
||||||
s3Storage := newS3StorageFailingDelete()
|
|
||||||
|
|
||||||
svc := NewSoraGenerationService(repo, s3Storage, quotaService)
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, exists := repo.gens[1]
|
|
||||||
require.False(t, exists)
|
|
||||||
require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
|
|
||||||
// 本地存储同样需要释放配额
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{
|
|
||||||
ID: 1, UserID: 1,
|
|
||||||
StorageType: SoraStorageTypeLocal,
|
|
||||||
FileSizeBytes: 1024,
|
|
||||||
}
|
|
||||||
|
|
||||||
userRepo := newStubUserRepoForQuota()
|
|
||||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
|
|
||||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
|
||||||
|
|
||||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
|
|
||||||
// FileSizeBytes=0 跳过配额释放
|
|
||||||
repo := newStubGenRepo()
|
|
||||||
repo.gens[1] = &SoraGeneration{
|
|
||||||
ID: 1, UserID: 1,
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
FileSizeBytes: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
userRepo := newStubUserRepoForQuota()
|
|
||||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
|
||||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
|
||||||
|
|
||||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
|
||||||
err := svc.Delete(context.Background(), 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
|
|
||||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
|
||||||
|
|
||||||
gen := &SoraGeneration{
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
|
|
||||||
MediaURL: "original",
|
|
||||||
}
|
|
||||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
|
|
||||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
|
||||||
|
|
||||||
gen := &SoraGeneration{
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{
|
|
||||||
"sora/1/2024/01/01/img1.png",
|
|
||||||
"sora/1/2024/01/01/img2.png",
|
|
||||||
"sora/1/2024/01/01/img3.png",
|
|
||||||
},
|
|
||||||
MediaURL: "original",
|
|
||||||
}
|
|
||||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
|
||||||
require.NoError(t, err)
|
|
||||||
// 主 URL 更新为第一个 key 的 CDN URL
|
|
||||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
|
|
||||||
// 多图 URLs 全部更新
|
|
||||||
require.Len(t, gen.MediaURLs, 3)
|
|
||||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
|
|
||||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
|
|
||||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
|
|
||||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
|
||||||
|
|
||||||
gen := &SoraGeneration{
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{},
|
|
||||||
MediaURL: "original",
|
|
||||||
}
|
|
||||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "original", gen.MediaURL) // 不变
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
|
|
||||||
// 使用无 settingService 的 S3 Storage,getClient 会失败
|
|
||||||
s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
|
||||||
|
|
||||||
gen := &SoraGeneration{
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
|
|
||||||
MediaURL: "original",
|
|
||||||
}
|
|
||||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
|
||||||
require.Error(t, err) // GetAccessURL 失败应传播错误
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
|
|
||||||
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
|
|
||||||
s3Storage := newS3StorageFailingDelete()
|
|
||||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
|
||||||
|
|
||||||
gen := &SoraGeneration{
|
|
||||||
StorageType: SoraStorageTypeS3,
|
|
||||||
S3ObjectKeys: []string{
|
|
||||||
"sora/1/2024/01/01/img1.png",
|
|
||||||
"sora/1/2024/01/01/img2.png",
|
|
||||||
},
|
|
||||||
MediaURL: "original",
|
|
||||||
}
|
|
||||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
|
||||||
require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
|
|
||||||
}
|
|
||||||
@ -1,120 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
||||||
"github.com/robfig/cron/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
|
||||||
|
|
||||||
// SoraMediaCleanupService 定期清理本地媒体文件
|
|
||||||
type SoraMediaCleanupService struct {
|
|
||||||
storage *SoraMediaStorage
|
|
||||||
cfg *config.Config
|
|
||||||
|
|
||||||
cron *cron.Cron
|
|
||||||
|
|
||||||
startOnce sync.Once
|
|
||||||
stopOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
|
||||||
return &SoraMediaCleanupService{
|
|
||||||
storage: storage,
|
|
||||||
cfg: cfg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraMediaCleanupService) Start() {
|
|
||||||
if s == nil || s.cfg == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !s.cfg.Sora.Storage.Cleanup.Enabled {
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (disabled)")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.storage == nil || !s.storage.Enabled() {
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (storage disabled)")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.startOnce.Do(func() {
|
|
||||||
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
|
|
||||||
if schedule == "" {
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (empty schedule)")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
loc := time.Local
|
|
||||||
if strings.TrimSpace(s.cfg.Timezone) != "" {
|
|
||||||
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
|
|
||||||
loc = parsed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
|
|
||||||
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.cron = c
|
|
||||||
s.cron.Start()
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraMediaCleanupService) Stop() {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.stopOnce.Do(func() {
|
|
||||||
if s.cron != nil {
|
|
||||||
ctx := s.cron.Stop()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cron stop timed out")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraMediaCleanupService) runCleanup() {
|
|
||||||
if s.cfg == nil || s.storage == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
|
|
||||||
if retention <= 0 {
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] skipped (retention_days=%d)", retention)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cutoff := time.Now().AddDate(0, 0, -retention)
|
|
||||||
deleted := 0
|
|
||||||
|
|
||||||
roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()}
|
|
||||||
for _, root := range roots {
|
|
||||||
if root == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
_ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if info.IsDir() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if info.ModTime().Before(cutoff) {
|
|
||||||
if rmErr := os.Remove(p); rmErr == nil {
|
|
||||||
deleted++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
logger.LegacyPrintf("service.sora_media_cleanup", "[SoraCleanup] cleanup finished, deleted=%d", deleted)
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user