diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 7eabde62..9386678d 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 8b060688..1fcba8fa 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.117 +0.1.118 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 4369d980..710ebb58 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -69,7 +69,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + affiliateRepository := repository.NewAffiliateRepository(client, db) + affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) @@ -80,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -91,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { announcementReadRepository := repository.NewAnnouncementReadRepository(client) announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository) announcementHandler := handler.NewAnnouncementHandler(announcementService) + channelMonitorRepository := repository.NewChannelMonitorRepository(client, db) + channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor) + channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService) dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db) dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig) dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig) @@ -196,7 +201,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) registry := payment.ProvideRegistry() defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) @@ -225,25 +230,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) - sqlDB, err := repository.ProvideSQLDB(client) - if err != nil { - return nil, err - } - channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB) - channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB) + channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) + channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db) channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) - channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor) - channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) - channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService) - channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService) windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository) windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService) - availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler) + affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler, affiliateHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) @@ -253,9 +250,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) + availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -273,6 +271,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) +<<<<<<< HEAD + channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService) application := &Application{ Server: httpServer, diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index a915f7ce..4dcfaa6b 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -653,6 +653,7 @@ func (h *AccountHandler) Delete(c *gin.Context) { type TestAccountRequest struct { ModelID string `json:"model_id"` Prompt string `json:"prompt"` + Mode string `json:"mode"` } type SyncFromCRSRequest struct { @@ -683,7 +684,7 @@ func (h *AccountHandler) Test(c *gin.Context) { _ = c.ShouldBindJSON(&req) // Use AccountTestService to test the account with SSE streaming - if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil { + if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil { // Error already sent via SSE, just log return } diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go new file mode 100644 index 00000000..97e649ec --- /dev/null +++ b/backend/internal/handler/admin/affiliate_handler.go @@ -0,0 +1,183 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// AffiliateHandler handles admin affiliate (邀请返利) management: +// listing users with custom settings, updating per-user invite codes +// and exclusive rebate rates, and batch operations. +type AffiliateHandler struct { + affiliateService *service.AffiliateService + adminService service.AdminService +} + +// NewAffiliateHandler creates a new admin affiliate handler. +func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler { + return &AffiliateHandler{ + affiliateService: affiliateService, + adminService: adminService, + } +} + +// ListUsers returns paginated users with custom affiliate settings. +// GET /api/v1/admin/affiliates/users +func (h *AffiliateHandler) ListUsers(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + search := c.Query("search") + + entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{ + Search: search, + Page: page, + PageSize: pageSize, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, entries, total, page, pageSize) +} + +// UpdateUserSettings updates a user's affiliate settings. +// PUT /api/v1/admin/affiliates/users/:user_id +// +// Both fields are optional and applied independently. +type UpdateAffiliateUserRequest struct { + AffCode *string `json:"aff_code"` + AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"` + // ClearRebateRate explicitly clears the per-user rate (sets it to NULL). + // Used to disambiguate from "field not provided". + ClearRebateRate bool `json:"clear_rebate_rate"` +} + +func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + + var req UpdateAffiliateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.AffCode != nil { + if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil { + response.ErrorFrom(c, err) + return + } + } + + if req.ClearRebateRate { + if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil { + response.ErrorFrom(c, err) + return + } + } else if req.AffRebateRatePercent != nil { + if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil { + response.ErrorFrom(c, err) + return + } + } + + response.Success(c, gin.H{"user_id": userID}) +} + +// ClearUserSettings removes ALL of a user's custom affiliate settings — clears +// the exclusive rebate rate AND regenerates the invite code as a new system +// random one. Conceptually this "removes the user from the custom list". +// +// Both writes happen in this handler; failure of one leaves the other applied, +// but the operation is idempotent so the admin can re-run it safely. +// DELETE /api/v1/admin/affiliates/users/:user_id +func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil { + response.ErrorFrom(c, err) + return + } + if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"user_id": userID}) +} + +// BatchSetRate applies the same rebate rate (or clears it) to multiple users. +// +// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is +// ignored). Otherwise aff_rebate_rate_percent is required and applied to +// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal +// can't distinguish a missing field from `null`, and a silent clear from a +// frontend that forgot to include the rate would be a footgun. +// +// POST /api/v1/admin/affiliates/users/batch-rate +type BatchSetRateRequest struct { + UserIDs []int64 `json:"user_ids" binding:"required"` + AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"` + Clear bool `json:"clear"` +} + +func (h *AffiliateHandler) BatchSetRate(c *gin.Context) { + var req BatchSetRateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if len(req.UserIDs) == 0 { + response.BadRequest(c, "user_ids cannot be empty") + return + } + if !req.Clear && req.AffRebateRatePercent == nil { + response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true") + return + } + rate := req.AffRebateRatePercent + if req.Clear { + rate = nil + } + if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"affected": len(req.UserIDs)}) +} + +// AffiliateUserSummary is the minimal user shape returned by LookupUsers, +// shared with the frontend's add-custom-user picker. +type AffiliateUserSummary struct { + ID int64 `json:"id"` + Email string `json:"email"` + Username string `json:"username"` +} + +// LookupUsers searches users by email/username for the "add custom user" modal. +// GET /api/v1/admin/affiliates/users/lookup?q= +func (h *AffiliateHandler) LookupUsers(c *gin.Context) { + keyword := c.Query("q") + if keyword == "" { + response.Success(c, []AffiliateUserSummary{}) + return + } + users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc") + if err != nil { + response.ErrorFrom(c, err) + return + } + result := make([]AffiliateUserSummary, len(users)) + for i, u := range users { + result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username} + } + response.Success(c, result) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 4277f0f1..40bf1c69 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + AffiliateRebateRate: settings.AffiliateRebateRate, DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, @@ -241,6 +242,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: settings.AvailableChannelsEnabled, + + AffiliateEnabled: settings.AffiliateEnabled, } response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } @@ -338,6 +341,7 @@ type UpdateSettingsRequest struct { // 默认配置 DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` @@ -439,6 +443,9 @@ type UpdateSettingsRequest struct { // Available Channels feature switch (user-facing) AvailableChannelsEnabled *bool `json:"available_channels_enabled"` + + // Affiliate (邀请返利) feature switch + AffiliateEnabled *bool `json:"affiliate_enabled"` } // UpdateSettings 更新系统设置 @@ -468,6 +475,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.DefaultBalance < 0 { req.DefaultBalance = 0 } + affiliateRebateRate := previousSettings.AffiliateRebateRate + if req.AffiliateRebateRate != nil { + affiliateRebateRate = *req.AffiliateRebateRate + } + if affiliateRebateRate < service.AffiliateRebateRateMin { + affiliateRebateRate = service.AffiliateRebateRateMin + } + if affiliateRebateRate > service.AffiliateRebateRateMax { + affiliateRebateRate = service.AffiliateRebateRateMax + } // 通用表格配置:兼容旧客户端未传字段时保留当前值。 if req.TableDefaultPageSize <= 0 { req.TableDefaultPageSize = previousSettings.TableDefaultPageSize @@ -1119,6 +1136,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + AffiliateRebateRate: affiliateRebateRate, DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, @@ -1252,6 +1270,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AvailableChannelsEnabled }(), + AffiliateEnabled: func() bool { + if req.AffiliateEnabled != nil { + return *req.AffiliateEnabled + } + return previousSettings.AffiliateEnabled + }(), } authSourceDefaults := &service.AuthSourceDefaultSettings{ @@ -1433,6 +1457,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, + AffiliateRebateRate: updatedSettings.AffiliateRebateRate, DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -1488,6 +1513,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, + + AffiliateEnabled: updatedSettings.AffiliateEnabled, } response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } @@ -1738,6 +1765,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.DefaultBalance != after.DefaultBalance { changed = append(changed, "default_balance") } + if before.AffiliateRebateRate != after.AffiliateRebateRate { + changed = append(changed, "affiliate_rebate_rate") + } if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { changed = append(changed, "default_subscriptions") } @@ -1853,6 +1883,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled { changed = append(changed, "available_channels_enabled") } + if before.AffiliateEnabled != after.AffiliateEnabled { + changed = append(changed, "affiliate_enabled") + } changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) return changed } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index dc68a466..1f9a66ff 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -48,6 +48,7 @@ type RegisterRequest struct { TurnstileToken string `json:"turnstile_token"` PromoCode string `json:"promo_code"` // 注册优惠码 InvitationCode string `json:"invitation_code"` // 邀请码 + AffCode string `json:"aff_code"` // 邀请返利码 } // SendVerifyCodeRequest 发送验证码请求 @@ -164,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) + _, user, err := h.authService.RegisterWithVerification( + c.Request.Context(), + req.Email, + req.Password, + req.VerifyCode, + req.PromoCode, + req.InvitationCode, + req.AffCode, + ) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index a4b7a297..ffe9ff5f 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars ( nil, nil, options.defaultSubAssigner, + nil, ) userSvc := service.NewUserService(userRepo, nil, nil, nil) var totpSvc *service.TotpService diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go index 1924cb81..f1c6d87d 100644 --- a/backend/internal/handler/auth_session_revocation_test.go +++ b/backend/internal/handler/auth_session_revocation_test.go @@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) { ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) handler := &AuthHandler{authService: authService} recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 7cf114c1..b3c7786d 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, nil, nil, nil, + nil, ) return &AuthHandler{ diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 2affbc46..051fab18 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -108,6 +108,7 @@ type SystemSettings struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` @@ -191,6 +192,9 @@ type SystemSettings struct { // Available Channels feature switch (user-facing aggregate view) AvailableChannelsEnabled bool `json:"available_channels_enabled"` + + // Affiliate (邀请返利) feature switch + AffiliateEnabled bool `json:"affiliate_enabled"` } type DefaultSubscriptionSetting struct { @@ -243,6 +247,8 @@ type PublicSettings struct { ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` AvailableChannelsEnabled bool `json:"available_channels_enabled"` + + AffiliateEnabled bool `json:"affiliate_enabled"` } // OverloadCooldownSettings 529过载冷却配置 DTO diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 9787897e..906ab95f 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -35,6 +35,7 @@ type AdminHandlers struct { ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler Payment *admin.PaymentHandler Windsurf *admin.WindsurfHandler + Affiliate *admin.AffiliateHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 3c4e6251..f395970a 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -130,6 +130,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + false, ) if err != nil { reqLog.Warn("openai_chat_completions.account_select_failed", @@ -153,6 +154,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { defaultModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + false, ) if err == nil && selection != nil { c.Set("openai_chat_completions_fallback_model", defaultModel) diff --git a/backend/internal/handler/openai_gateway_compact_log_test.go b/backend/internal/handler/openai_gateway_compact_log_test.go index 062f318b..e18509b4 100644 --- a/backend/internal/handler/openai_gateway_compact_log_test.go +++ b/backend/internal/handler/openai_gateway_compact_log_test.go @@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0") c.Set(opsModelKey, "gpt-5.3-codex") c.Set(opsAccountIDKey, int64(123)) c.Header("x-request-id", "rid-compact-ok") @@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0") c.Status(http.StatusBadGateway) h := &OpenAIGatewayHandler{} @@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) { c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`)) c.Request.Header.Set("Content-Type", "application/json") - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0") h := &OpenAIGatewayHandler{} h.Responses(c) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 1c975573..7676ffa3 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -238,6 +238,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Generate session hash (header first; fallback to prompt_cache_key) sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody) + requireCompact := isOpenAIRemoteCompactPath(c) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 @@ -256,6 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + requireCompact, ) if err != nil { reqLog.Warn("openai.account_select_failed", @@ -263,6 +265,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { + if errors.Is(err, service.ErrNoAvailableCompactAccounts) { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted) + return + } h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } @@ -644,6 +650,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { currentRoutingModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + false, ) if err != nil { reqLog.Warn("openai_messages.account_select_failed", @@ -1167,6 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { reqModel, nil, service.OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) if err != nil { reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go index a7bc4ba3..377f432e 100644 --- a/backend/internal/handler/payment_handler_resume_test.go +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) { Save(context.Background()) require.NoError(t, err) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() @@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing require.NoError(t, err) configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() @@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t require.NoError(t, err) configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() @@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) t.Cleanup(func() { _ = client.Close() }) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 96964de4..22f2aa15 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -75,5 +75,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: settings.AvailableChannelsEnabled, + + AffiliateEnabled: settings.AffiliateEnabled, }) } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index f74c2b72..3f6ed8c2 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -14,10 +14,11 @@ import ( // UserHandler handles user-related requests type UserHandler struct { - userService *service.UserService - authService *service.AuthService - emailService *service.EmailService - emailCache service.EmailCache + userService *service.UserService + authService *service.AuthService + emailService *service.EmailService + emailCache service.EmailCache + affiliateService *service.AffiliateService } // NewUserHandler creates a new UserHandler @@ -26,12 +27,14 @@ func NewUserHandler( authService *service.AuthService, emailService *service.EmailService, emailCache service.EmailCache, + affiliateService *service.AffiliateService, ) *UserHandler { return &UserHandler{ - userService: userService, - authService: authService, - emailService: emailService, - emailCache: emailCache, + userService: userService, + authService: authService, + emailService: emailService, + emailCache: emailCache, + affiliateService: affiliateService, } } @@ -159,6 +162,44 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { response.Success(c, profileResp) } +// GetAffiliate returns the current user's affiliate details. +// GET /api/v1/user/aff +func (h *UserHandler) GetAffiliate(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, detail) +} + +// TransferAffiliateQuota transfers all available affiliate quota into current balance. +// POST /api/v1/user/aff/transfer +func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "transferred_quota": transferred, + "balance": balance, + }) +} + type StartIdentityBindingRequest struct { Provider string `json:"provider" binding:"required"` RedirectTo string `json:"redirect_to"` diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 452dee09..8a864b51 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) recorder := httptest.NewRecorder() @@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -270,20 +270,20 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { AvatarURL: "https://cdn.example.com/linuxdo.png", AvatarSource: "remote_url", }, - identities: []service.UserAuthIdentityRecord{ - { - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: "linuxdo-subject-21", - VerifiedAt: &verifiedAt, - Metadata: map[string]any{ - "username": "linuxdo-handle", - "avatar_url": "https://cdn.example.com/linuxdo.png", + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, }, }, - }, - } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) recorder := httptest.NewRecorder() @@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t * ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) recorder := httptest.NewRecorder() @@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 57733ba7..d440d115 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -38,6 +38,7 @@ func ProvideAdminHandlers( channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, paymentHandler *admin.PaymentHandler, windsurfHandler *admin.WindsurfHandler, + affiliateHandler *admin.AffiliateHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -69,6 +70,7 @@ func ProvideAdminHandlers( ChannelMonitorTemplate: channelMonitorTemplateHandler, Payment: paymentHandler, Windsurf: windsurfHandler, + Affiliate: affiliateHandler, } } @@ -179,6 +181,7 @@ var ProviderSet = wire.NewSet( admin.NewChannelMonitorHandler, admin.NewChannelMonitorRequestTemplateHandler, admin.NewPaymentHandler, + admin.NewAffiliateHandler, // Windsurf handler ProvideWindsurfHandler, diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go index f0a5b07e..49426b88 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go @@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool { var out []AnthropicTool for _, t := range tools { switch t.Type { - case "web_search": + case "web_search", "google_search", "web_search_20250305": out = append(out, AnthropicTool{ Type: "web_search_20250305", Name: "web_search", diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index e0d1a53e..f8c6b75f 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -12,17 +12,23 @@ import "encoding/json" // AnthropicRequest is the request body for POST /v1/messages. type AnthropicRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock - Messages []AnthropicMessage `json:"messages"` - Tools []AnthropicTool `json:"tools,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - StopSeqs []string `json:"stop_sequences,omitempty"` - Thinking *AnthropicThinking `json:"thinking,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock + Messages []AnthropicMessage `json:"messages"` + Tools []AnthropicTool `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + StopSeqs []string `json:"stop_sequences,omitempty"` + Thinking *AnthropicThinking `json:"thinking,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + // Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id + // 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化 + // 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/ + // RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的 + // user_id,进而导致请求被归类为第三方 app。 + Metadata json.RawMessage `json:"metadata,omitempty"` OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"` } @@ -76,10 +82,18 @@ type AnthropicImageSource struct { // AnthropicTool describes a tool available to the model. type AnthropicTool struct { - Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools - Name string `json:"name"` - Description string `json:"description,omitempty"` - InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object + Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object + CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"` +} + +// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。 +// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。 +type AnthropicCacheControl struct { + Type string `json:"type"` // "ephemeral" + TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m(由 Anthropic 判定) } // AnthropicResponse is the non-streaming response from POST /v1/messages. diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index b68c5ccd..ec3c2e6b 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -210,6 +210,12 @@ func AttributionHeaderDisabled() bool { } // Beta header 常量 +// +// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04)。 +// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致, +// 原因:Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源; +// 缺少任何"官方 Claude Code 请求才会带"的 beta,都会被降级到第三方额度, +// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.` const ( BetaOAuth = "oauth-2025-04-20" BetaClaudeCode = "claude-code-20250219" @@ -218,15 +224,17 @@ const ( BetaTokenCounting = "token-counting-2024-11-01" BetaContext1M = "context-1m-2025-08-07" BetaFastMode = "fast-mode-2026-02-01" - BetaRedactThinking = "redact-thinking-2026-02-12" - BetaContextManagement = "context-management-2025-06-27" - BetaPromptCachingScope = "prompt-caching-scope-2026-01-05" - BetaEffort = "effort-2025-11-24" - BetaTaskBudgets = "task-budgets-2026-03-13" - BetaTokenEfficientTools = "token-efficient-tools-2026-03-28" - BetaStructuredOutputs = "structured-outputs-2025-12-15" - BetaAdvisor = "advisor-tool-2026-03-01" - BetaWebSearch = "web-search-2025-03-05" + // 新增(对齐官方 CLI 2.1.9x 以来的流量) + BetaPromptCachingScope = "prompt-caching-scope-2026-01-05" + BetaEffort = "effort-2025-11-24" + BetaRedactThinking = "redact-thinking-2026-02-12" + BetaContextManagement = "context-management-2025-06-27" + BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11" + BetaTaskBudgets = "task-budgets-2026-03-13" + BetaTokenEfficientTools = "token-efficient-tools-2026-03-28" + BetaStructuredOutputs = "structured-outputs-2025-12-15" + BetaAdvisor = "advisor-tool-2026-03-01" + BetaWebSearch = "web-search-2025-03-05" ) // DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。 @@ -291,32 +299,53 @@ func GetAPIKeyBetaHeader(modelID string) string { return APIKeyBetaHeader } -// DefaultHeaders 是 Claude Code 客户端默认请求头。 -var DefaultHeaders = buildDefaultHeaders(DefaultDeviceProfile()) +// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。 +// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先; +// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。 +const DefaultCacheControlTTL = "5m" -// ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值) -// cliVersion: Claude CLI 版本(如 "2.1.81") -// pkgVersion: SDK 版本(如 "0.80.0") -// runtimeVersion: Node.js 版本(如 "v24.13.0") -// os_: 操作系统(如 "Linux") -// arch: 架构(如 "arm64") -func ApplyFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) { - if cliVersion != "" { - DefaultCLIVersion = strings.TrimSpace(cliVersion) +// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver)。 +// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。 +// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。 +const CLICurrentVersion = "2.1.92" + +// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表, +// 用于 OAuth 账号伪装成 Claude Code 时使用。 +// 顺序与真实 CLI 抓包一致。 +// +// 使用建议: +// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。 +// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。 +// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。 +func FullClaudeCodeMimicryBetas() []string { + return []string{ + BetaClaudeCode, + BetaOAuth, + BetaInterleavedThinking, + BetaPromptCachingScope, + BetaEffort, + BetaRedactThinking, + BetaContextManagement, + BetaExtendedCacheTTL, } - if pkgVersion != "" { - DefaultStainlessPackageVersion = strings.TrimSpace(pkgVersion) - } - if runtimeVersion != "" { - DefaultStainlessRuntimeVersion = strings.TrimSpace(runtimeVersion) - } - if os_ != "" { - DefaultStainlessOS = strings.TrimSpace(os_) - } - if arch != "" { - DefaultStainlessArch = strings.TrimSpace(arch) - } - DefaultHeaders = buildDefaultHeaders(DefaultDeviceProfile()) +} + +// DefaultHeaders 是 Claude Code 客户端默认请求头。 +var DefaultHeaders = map[string]string{ + // Keep these in sync with recent Claude CLI traffic to reduce the chance + // that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage. + // 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。 + "User-Agent": "claude-cli/2.1.92 (external, cli)", + "X-Stainless-Lang": "js", + "X-Stainless-Package-Version": "0.70.0", + "X-Stainless-OS": "Linux", + "X-Stainless-Arch": "arm64", + "X-Stainless-Runtime": "node", + "X-Stainless-Runtime-Version": "v24.13.0", + "X-Stainless-Retry-Count": "0", + "X-Stainless-Timeout": "600", + "X-App": "cli", + "Anthropic-Dangerous-Direct-Browser-Access": "true", } // Model 表示一个 Claude 模型 diff --git a/backend/internal/repository/account_repo_compact_extra_test.go b/backend/internal/repository/account_repo_compact_extra_test.go new file mode 100644 index 00000000..604f392e --- /dev/null +++ b/backend/internal/repository/account_repo_compact_extra_test.go @@ -0,0 +1,14 @@ +package repository + +import "testing" + +func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) { + updates := map[string]any{ + "openai_compact_supported": true, + "openai_compact_checked_at": "2026-04-10T10:00:00Z", + } + + if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { + t.Fatalf("expected compact capability updates to enqueue scheduler outbox") + } +} diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go new file mode 100644 index 00000000..e3dd56b8 --- /dev/null +++ b/backend/internal/repository/affiliate_repo.go @@ -0,0 +1,664 @@ +package repository + +import ( + "context" + "crypto/rand" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +const ( + affiliateCodeLength = 12 + affiliateCodeMaxAttempts = 12 +) + +var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") + +type affiliateQueryExecer interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +type affiliateRepository struct { + client *dbent.Client +} + +func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository { + return &affiliateRepository{client: client} +} + +func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) { + if userID <= 0 { + return nil, service.ErrUserNotFound + } + client := clientFromContext(ctx, r.client) + return ensureUserAffiliateWithClient(ctx, client, userID) +} + +func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) { + client := clientFromContext(ctx, r.client) + return queryAffiliateByCode(ctx, client, code) +} + +func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) { + var bound bool + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil { + return err + } + + res, err := txClient.ExecContext(txCtx, + "UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL", + inviterID, userID, + ) + if err != nil { + return fmt.Errorf("bind inviter: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + bound = false + return nil + } + + if _, err = txClient.ExecContext(txCtx, + "UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1", + inviterID, + ); err != nil { + return fmt.Errorf("increment inviter aff_count: %w", err) + } + bound = true + return nil + }) + if err != nil { + return false, err + } + return bound, nil +} + +func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) { + if amount <= 0 { + return false, nil + } + + var applied bool + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + res, err := txClient.ExecContext(txCtx, + "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2", + amount, inviterID, + ) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + applied = false + return nil + } + + if _, err = txClient.ExecContext(txCtx, ` +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { + return fmt.Errorf("insert affiliate accrue ledger: %w", err) + } + + applied = true + return nil + }) + if err != nil { + return false, err + } + return applied, nil +} + +func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) { + var transferred float64 + var newBalance float64 + + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + + rows, err := txClient.QueryContext(txCtx, ` +WITH claimed AS ( + SELECT aff_quota::double precision AS amount + FROM user_affiliates + WHERE user_id = $1 + AND aff_quota > 0 + FOR UPDATE +), +cleared AS ( + UPDATE user_affiliates ua + SET aff_quota = 0, + updated_at = NOW() + FROM claimed c + WHERE ua.user_id = $1 + RETURNING c.amount +) +SELECT amount +FROM cleared`, userID) + if err != nil { + return fmt.Errorf("claim affiliate quota: %w", err) + } + + if !rows.Next() { + _ = rows.Close() + if err := rows.Err(); err != nil { + return err + } + return service.ErrAffiliateQuotaEmpty + } + if err := rows.Scan(&transferred); err != nil { + _ = rows.Close() + return err + } + if err := rows.Close(); err != nil { + return err + } + if transferred <= 0 { + return service.ErrAffiliateQuotaEmpty + } + + affected, err := txClient.User.Update(). + Where(user.IDEQ(userID)). + AddBalance(transferred). + AddTotalRecharged(transferred). + Save(txCtx) + if err != nil { + return fmt.Errorf("credit user balance by affiliate quota: %w", err) + } + if affected == 0 { + return service.ErrUserNotFound + } + + newBalance, err = queryUserBalance(txCtx, txClient, userID) + if err != nil { + return err + } + + if _, err = txClient.ExecContext(txCtx, ` +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) +VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil { + return fmt.Errorf("insert affiliate transfer ledger: %w", err) + } + + return nil + }) + if err != nil { + return 0, 0, err + } + + return transferred, newBalance, nil +} + +func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) { + if limit <= 0 { + limit = 100 + } + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.created_at +FROM user_affiliates ua +LEFT JOIN users u ON u.id = ua.user_id +WHERE ua.inviter_id = $1 +ORDER BY ua.created_at DESC +LIMIT $2`, inviterID, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + invitees := make([]service.AffiliateInvitee, 0) + for rows.Next() { + var item service.AffiliateInvitee + var createdAt time.Time + if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil { + return nil, err + } + item.CreatedAt = &createdAt + invitees = append(invitees, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return invitees, nil +} + +func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return fn(ctx, tx.Client()) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return fmt.Errorf("begin affiliate transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx, tx.Client()); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit affiliate transaction: %w", err) + } + return nil +} + +func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) { + summary, err := queryAffiliateByUserID(ctx, client, userID) + if err == nil { + return summary, nil + } + if !errors.Is(err, service.ErrAffiliateProfileNotFound) { + return nil, err + } + + for i := 0; i < affiliateCodeMaxAttempts; i++ { + code, codeErr := generateAffiliateCode() + if codeErr != nil { + return nil, codeErr + } + _, insertErr := client.ExecContext(ctx, ` +INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at) +VALUES ($1, $2, NOW(), NOW()) +ON CONFLICT (user_id) DO NOTHING`, userID, code) + if insertErr == nil { + break + } + if isAffiliateUniqueViolation(insertErr) { + continue + } + return nil, insertErr + } + + return queryAffiliateByUserID(ctx, client, userID) +} + +func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) { + rows, err := client.QueryContext(ctx, ` +SELECT user_id, + aff_code, + aff_code_custom, + aff_rebate_rate_percent, + inviter_id, + aff_count, + aff_quota::double precision, + aff_history_quota::double precision, + created_at, + updated_at +FROM user_affiliates +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrAffiliateProfileNotFound + } + + var out service.AffiliateSummary + var inviterID sql.NullInt64 + var rebateRate sql.NullFloat64 + if err := rows.Scan( + &out.UserID, + &out.AffCode, + &out.AffCodeCustom, + &rebateRate, + &inviterID, + &out.AffCount, + &out.AffQuota, + &out.AffHistoryQuota, + &out.CreatedAt, + &out.UpdatedAt, + ); err != nil { + return nil, err + } + if inviterID.Valid { + out.InviterID = &inviterID.Int64 + } + if rebateRate.Valid { + v := rebateRate.Float64 + out.AffRebateRatePercent = &v + } + return &out, nil +} + +func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) { + rows, err := client.QueryContext(ctx, ` +SELECT user_id, + aff_code, + aff_code_custom, + aff_rebate_rate_percent, + inviter_id, + aff_count, + aff_quota::double precision, + aff_history_quota::double precision, + created_at, + updated_at +FROM user_affiliates +WHERE aff_code = $1 +LIMIT 1`, strings.ToUpper(strings.TrimSpace(code))) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrAffiliateProfileNotFound + } + + var out service.AffiliateSummary + var inviterID sql.NullInt64 + var rebateRate sql.NullFloat64 + if err := rows.Scan( + &out.UserID, + &out.AffCode, + &out.AffCodeCustom, + &rebateRate, + &inviterID, + &out.AffCount, + &out.AffQuota, + &out.AffHistoryQuota, + &out.CreatedAt, + &out.UpdatedAt, + ); err != nil { + return nil, err + } + if inviterID.Valid { + out.InviterID = &inviterID.Int64 + } + if rebateRate.Valid { + v := rebateRate.Float64 + out.AffRebateRatePercent = &v + } + return &out, nil +} + +func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) { + rows, err := client.QueryContext(ctx, + "SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1", + userID, + ) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return 0, err + } + return 0, service.ErrUserNotFound + } + var balance float64 + if err := rows.Scan(&balance); err != nil { + return 0, err + } + return balance, nil +} + +func generateAffiliateCode() (string, error) { + buf := make([]byte, affiliateCodeLength) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("generate affiliate code: %w", err) + } + for i := range buf { + buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)] + } + return string(buf), nil +} + +func isAffiliateUniqueViolation(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) { + return string(pqErr.Code) == "23505" + } + return false +} + +// UpdateUserAffCode 改写用户的邀请码(自定义专属邀请码)。 +// 唯一性冲突返回 ErrAffiliateCodeTaken。 +func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error { + if userID <= 0 { + return service.ErrUserNotFound + } + code := strings.ToUpper(strings.TrimSpace(newCode)) + if code == "" { + return service.ErrAffiliateCodeInvalid + } + + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + res, err := txClient.ExecContext(txCtx, ` +UPDATE user_affiliates +SET aff_code = $1, + aff_code_custom = true, + updated_at = NOW() +WHERE user_id = $2`, code, userID) + if err != nil { + if isAffiliateUniqueViolation(err) { + return service.ErrAffiliateCodeTaken + } + return fmt.Errorf("update aff_code: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return service.ErrUserNotFound + } + return nil + }) +} + +// ResetUserAffCode 把 aff_code 还原为系统随机码,并清除 aff_code_custom 标记。 +func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) { + if userID <= 0 { + return "", service.ErrUserNotFound + } + var newCode string + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + for i := 0; i < affiliateCodeMaxAttempts; i++ { + candidate, codeErr := generateAffiliateCode() + if codeErr != nil { + return codeErr + } + res, err := txClient.ExecContext(txCtx, ` +UPDATE user_affiliates +SET aff_code = $1, + aff_code_custom = false, + updated_at = NOW() +WHERE user_id = $2`, candidate, userID) + if err != nil { + if isAffiliateUniqueViolation(err) { + continue + } + return fmt.Errorf("reset aff_code: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return service.ErrUserNotFound + } + newCode = candidate + return nil + } + return fmt.Errorf("reset aff_code: exhausted attempts") + }) + if err != nil { + return "", err + } + return newCode, nil +} + +// SetUserRebateRate 设置或清除用户专属返利比例。ratePercent==nil 表示清除(沿用全局)。 +func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error { + if userID <= 0 { + return service.ErrUserNotFound + } + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + // nullableArg lets us use a single UPDATE for both "set value" and + // "clear" cases — database/sql converts nil interface{} to SQL NULL. + res, err := txClient.ExecContext(txCtx, ` +UPDATE user_affiliates +SET aff_rebate_rate_percent = $1, + updated_at = NOW() +WHERE user_id = $2`, nullableArg(ratePercent), userID) + if err != nil { + return fmt.Errorf("set aff_rebate_rate_percent: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + return service.ErrUserNotFound + } + return nil + }) +} + +// BatchSetUserRebateRate 批量为多个用户设置专属比例(nil 清除)。 +func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error { + if len(userIDs) == 0 { + return nil + } + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + for _, uid := range userIDs { + if uid <= 0 { + continue + } + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil { + return err + } + } + _, err := txClient.ExecContext(txCtx, ` +UPDATE user_affiliates +SET aff_rebate_rate_percent = $1, + updated_at = NOW() +WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs)) + if err != nil { + return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err) + } + return nil + }) +} + +// nullableArg unwraps a *float64 into an interface{} suitable for SQL parameter +// binding: nil pointer → SQL NULL, non-nil → the float value. +func nullableArg(v *float64) any { + if v == nil { + return nil + } + return *v +} + +// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。 +// +// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索": +// 空 search 时拼接出的 LIKE 模式为 "%%",匹配所有行;非空时按 ILIKE 子串匹配。 +// 这避免了为两种情况维护两份 SQL 模板。 +func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) { + page := filter.Page + if page < 1 { + page = 1 + } + pageSize := filter.PageSize + if pageSize <= 0 || pageSize > 200 { + pageSize = 20 + } + offset := (page - 1) * pageSize + likePattern := "%" + strings.TrimSpace(filter.Search) + "%" + + const baseFrom = ` +FROM user_affiliates ua +JOIN users u ON u.id = ua.user_id +WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL) + AND (u.email ILIKE $1 OR u.username ILIKE $1)` + + client := clientFromContext(ctx, r.client) + + total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern) + if err != nil { + return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err) + } + + listQuery := ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.aff_code, + ua.aff_code_custom, + ua.aff_rebate_rate_percent, + ua.aff_count` + baseFrom + ` +ORDER BY ua.updated_at DESC +LIMIT $2 OFFSET $3` + + rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset) + if err != nil { + return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err) + } + defer func() { _ = rows.Close() }() + + entries := make([]service.AffiliateAdminEntry, 0) + for rows.Next() { + var e service.AffiliateAdminEntry + var rebate sql.NullFloat64 + if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode, + &e.AffCodeCustom, &rebate, &e.AffCount); err != nil { + return nil, 0, err + } + if rebate.Valid { + v := rebate.Float64 + e.AffRebateRatePercent = &v + } + entries = append(entries, e) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return entries, total, nil +} + +// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT). +func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) { + rows, err := client.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return 0, err + } + return 0, nil + } + var v int64 + if err := rows.Scan(&v); err != nil { + return 0, err + } + return v, nil +} diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go new file mode 100644 index 00000000..369f57cf --- /dev/null +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -0,0 +1,399 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 { + t.Helper() + rows, err := client.QueryContext(ctx, query, args...) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + require.True(t, rows.Next(), "expected one row") + var value float64 + require.NoError(t, rows.Scan(&value)) + require.NoError(t, rows.Err()) + return value +} + +func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int { + t.Helper() + rows, err := client.QueryContext(ctx, query, args...) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + require.True(t, rows.Next(), "expected one row") + var value int + require.NoError(t, rows.Scan(&value)) + require.NoError(t, rows.Err()) + return value +} + +func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 5.5, + Concurrency: 5, + }) + + affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000) + _, err := client.ExecContext(txCtx, ` +INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at) +VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34) + require.NoError(t, err) + + transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID) + require.NoError(t, err) + require.InDelta(t, 12.34, transferred, 1e-9) + require.InDelta(t, 17.84, balance, 1e-9) + + affQuota := querySingleFloat(t, txCtx, client, + "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID) + require.InDelta(t, 0.0, affQuota, 1e-9) + + persistedBalance := querySingleFloat(t, txCtx, client, + "SELECT balance::double precision FROM users WHERE id = $1", u.ID) + require.InDelta(t, 17.84, persistedBalance, 1e-9) + + ledgerCount := querySingleInt(t, txCtx, client, + "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID) + require.Equal(t, 1, ledgerCount) +} + +// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the +// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx +// that already carries a transaction (via dbent.NewTxContext), repo.withTx +// must reuse that tx rather than opening a nested one. If this invariant +// breaks, AccrueQuota would commit independently and survive a rollback of +// the outer tx, which would violate payment_fulfillment's all-or-nothing +// semantics. +func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { + ctx := context.Background() + + outerTx, err := integrationEntClient.Tx(ctx) + require.NoError(t, err, "begin outer tx") + // Defensive cleanup: if any require.* below fires before the explicit + // Rollback, this prevents the tx from leaking until container teardown. + // Rollback is idempotent at the driver level (extra rollback returns an + // error we ignore). + t.Cleanup(func() { _ = outerTx.Rollback() }) + client := outerTx.Client() + txCtx := dbent.NewTxContext(ctx, outerTx) + + inviter := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 5, + }) + invitee := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 5, + }) + + repo := NewAffiliateRepository(client, integrationDB) + _, err = repo.EnsureUserAffiliate(txCtx, inviter.ID) + require.NoError(t, err) + _, err = repo.EnsureUserAffiliate(txCtx, invitee.ID) + require.NoError(t, err) + + bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID) + require.NoError(t, err) + require.True(t, bound, "invitee must bind to inviter") + + applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5) + require.NoError(t, err) + require.True(t, applied, "AccrueQuota must report applied=true") + + // Visible inside the outer tx. + innerQuota := querySingleFloat(t, txCtx, client, + "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID) + require.InDelta(t, 3.5, innerQuota, 1e-9) + + // Roll back the outer tx; if AccrueQuota had opened its own inner tx and + // committed it, the rows would still be visible to the global client. + require.NoError(t, outerTx.Rollback()) + + rows, err := integrationEntClient.QueryContext(ctx, + "SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)", + inviter.ID, invitee.ID) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + require.True(t, rows.Next()) + var postRollbackCount int + require.NoError(t, rows.Scan(&postRollbackCount)) + require.Equal(t, 0, postRollbackCount, + "AccrueQuota must propagate the outer tx — found persisted rows after rollback") +} + +func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 3.21, + Concurrency: 5, + }) + + affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000) + _, err := client.ExecContext(txCtx, ` +INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at) +VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode) + require.NoError(t, err) + + transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID) + require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty) + require.InDelta(t, 0.0, transferred, 1e-9) + require.InDelta(t, 0.0, balance, 1e-9) + + persistedBalance := querySingleFloat(t, txCtx, client, + "SELECT balance::double precision FROM users WHERE id = $1", u.ID) + require.InDelta(t, 3.21, persistedBalance, 1e-9) +} + +// TestAffiliateRepository_AdminCustomCode covers the success path of admin +// invite-code rewrite + reset within a shared test transaction: +// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works +// - the old code can no longer be found +// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code +// +// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own +// test because a unique-violation aborts the surrounding Postgres tx, which +// would poison subsequent assertions in the same transaction. +func TestAffiliateRepository_AdminCustomCode(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + + original, err := repo.EnsureUserAffiliate(txCtx, u.ID) + require.NoError(t, err) + require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom") + originalCode := original.AffCode + + // Rewrite to a custom code + customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000) + require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode)) + + updated, err := repo.EnsureUserAffiliate(txCtx, u.ID) + require.NoError(t, err) + require.Equal(t, customCode, updated.AffCode) + require.True(t, updated.AffCodeCustom) + + // Lookup by new custom code finds the user + byCode, err := repo.GetAffiliateByCode(txCtx, customCode) + require.NoError(t, err) + require.Equal(t, u.ID, byCode.UserID) + + // Old system code should no longer match + _, err = repo.GetAffiliateByCode(txCtx, originalCode) + require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound) + + // Reset back to a fresh system code, clears custom flag + newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID) + require.NoError(t, err) + require.NotEqual(t, customCode, newSysCode) + + reset, err := repo.EnsureUserAffiliate(txCtx, u.ID) + require.NoError(t, err) + require.Equal(t, newSysCode, reset.AffCode) + require.False(t, reset.AffCodeCustom) + + // The old custom code is now free again + _, err = repo.GetAffiliateByCode(txCtx, customCode) + require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound) +} + +// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation +// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so +// this test must be the only assertion and run in its own tx — production +// callers each have their own outer tx, so this matches real behavior. +func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + taker := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, Status: service.StatusActive, + }) + requester := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, Status: service.StatusActive, + }) + + takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000) + require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode)) + + // Now requester tries to grab the same code → conflict. + err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode) + require.ErrorIs(t, err, service.ErrAffiliateCodeTaken) +} + +// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate +// set/clear and the Batch variant including NULL semantics. +func TestAffiliateRepository_AdminRebateRate(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + u1 := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + u2 := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + + // Set exclusive rate for u1 + rate := 42.5 + require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate)) + + got, err := repo.EnsureUserAffiliate(txCtx, u1.ID) + require.NoError(t, err) + require.NotNil(t, got.AffRebateRatePercent) + require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9) + + // Clear exclusive rate + require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil)) + cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID) + require.NoError(t, err) + require.Nil(t, cleared.AffRebateRatePercent) + + // Batch set both users + batchRate := 15.0 + require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate)) + + for _, uid := range []int64{u1.ID, u2.ID} { + v, err := repo.EnsureUserAffiliate(txCtx, uid) + require.NoError(t, err) + require.NotNil(t, v.AffRebateRatePercent) + require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9) + } + + // Batch clear + require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil)) + for _, uid := range []int64{u1.ID, u2.ID} { + v, err := repo.EnsureUserAffiliate(txCtx, uid) + require.NoError(t, err) + require.Nil(t, v.AffRebateRatePercent) + } +} + +// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list +// only includes users with at least one override applied. +func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + // User without any custom config — should NOT appear in the list. + plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano()) + uPlain := mustCreateUser(t, client, &service.User{ + Email: plainEmail, PasswordHash: "hash", + Role: service.RoleUser, Status: service.StatusActive, + }) + _, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID) + require.NoError(t, err) + + // User with a custom code — should appear. + uCode := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, Status: service.StatusActive, + }) + require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000))) + + // User with only an exclusive rate — should appear. + uRate := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, Status: service.StatusActive, + }) + r := 33.3 + require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r)) + + entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{ + Page: 1, PageSize: 100, + }) + require.NoError(t, err) + + // Build a quick lookup to assert per-user attributes (other tests may have + // inserted custom rows in the same DB; we only care about our 3). + byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries)) + for _, e := range entries { + byUserID[e.UserID] = e + } + + require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear") + + codeEntry, ok := byUserID[uCode.ID] + require.True(t, ok, "custom-code user missing from list") + require.True(t, codeEntry.AffCodeCustom) + require.Nil(t, codeEntry.AffRebateRatePercent) + + rateEntry, ok := byUserID[uRate.ID] + require.True(t, ok, "custom-rate user missing from list") + require.False(t, rateEntry.AffCodeCustom) + require.NotNil(t, rateEntry.AffRebateRatePercent) + require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9) + + require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows") +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 6d24d312..f07bbb33 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet( NewChannelRepository, NewChannelMonitorRepository, NewChannelMonitorRequestTemplateRepository, + NewAffiliateRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index cb19f82c..d605c52b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -715,6 +715,7 @@ func TestAPIContracts(t *testing.T) { "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "affiliate_rebate_rate": 20, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, @@ -774,6 +775,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "affiliate_enabled": false, "wechat_connect_enabled": false, "wechat_connect_app_id": "", "wechat_connect_app_secret_configured": false, @@ -895,6 +897,7 @@ func TestAPIContracts(t *testing.T) { "custom_endpoints": [], "default_concurrency": 0, "default_balance": 0, + "affiliate_rebate_rate": 20, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, @@ -949,6 +952,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "affiliate_enabled": false, "wechat_connect_enabled": true, "wechat_connect_app_id": "wx-open-config", "wechat_connect_app_secret_configured": true, diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 06e3355e..dde92dfd 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index 84fd6967..a643d3bc 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) @@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) { cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) toucher := &recordingActivityToucher{} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 1ef5af4a..74109261 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -94,6 +94,9 @@ func RegisterAdminRoutes( // 渠道监控 registerChannelMonitorRoutes(admin, h) + + // 邀请返利(专属用户管理) + registerAffiliateRoutes(admin, h) } } @@ -615,3 +618,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) { templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply) } } + +// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置) +func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + affiliates := admin.Group("/affiliates") + { + users := affiliates.Group("/users") + { + users.GET("", h.Admin.Affiliate.ListUsers) + users.GET("/lookup", h.Admin.Affiliate.LookupUsers) + users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate) + users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings) + users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings) + } + } +} diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index babab125..9976954c 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -25,6 +25,8 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + user.GET("/aff", h.User.GetAffiliate) + user.POST("/aff/transfer", h.User.TransferAffiliateQuota) user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode) user.POST("/account-bindings/email", h.User.BindEmailIdentity) user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index edd5313f..cd06ffa3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -274,26 +274,6 @@ func (a *Account) GetCredentialAsInt64(key string) int64 { return 0 } -// GetCredentialAsBool 解析凭证中的 bool 字段,支持 bool 和 "true"/"false" 字符串 -func (a *Account) GetCredentialAsBool(key string) bool { - if a == nil || a.Credentials == nil { - return false - } - val, ok := a.Credentials[key] - if !ok || val == nil { - return false - } - switch v := val.(type) { - case bool: - return v - case string: - return strings.EqualFold(strings.TrimSpace(v), "true") - case float64: - return v != 0 - } - return false -} - func (a *Account) IsTempUnschedulableEnabled() bool { if a.Credentials == nil { return false @@ -413,6 +393,56 @@ func parseTempUnschedInt(value any) int { return 0 } +const ( + // OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility. + OpenAICompactModeAuto = "auto" + // OpenAICompactModeForceOn always treats the account as compact-supported. + OpenAICompactModeForceOn = "force_on" + // OpenAICompactModeForceOff always treats the account as compact-unsupported. + OpenAICompactModeForceOff = "force_off" +) + +func normalizeOpenAICompactMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case OpenAICompactModeForceOn: + return OpenAICompactModeForceOn + case OpenAICompactModeForceOff: + return OpenAICompactModeForceOff + default: + return OpenAICompactModeAuto + } +} + +func stringMappingFromRaw(raw any) map[string]string { + switch mapping := raw.(type) { + case map[string]any: + if len(mapping) == 0 { + return nil + } + result := make(map[string]string, len(mapping)) + for key, value := range mapping { + if str, ok := value.(string); ok { + result[key] = str + } + } + if len(result) == 0 { + return nil + } + return result + case map[string]string: + if len(mapping) == 0 { + return nil + } + result := make(map[string]string, len(mapping)) + for key, value := range mapping { + result[key] = value + } + return result + default: + return nil + } +} + func (a *Account) GetModelMapping() map[string]string { credentialsPtr := mapPtr(a.Credentials) rawMapping, _ := a.Credentials["model_mapping"].(map[string]any) @@ -618,24 +648,75 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, return requestedModel, false } -// AntigravityUpstreamType 标识 Antigravity APIKey 账号对接的上游形态。 -// -// - "sub2api"(默认):对接另一个 sub2api 实例,路径需要追加 /antigravity 前缀 -// - "newapi":对接 newapi/one-api 风格的中转,直接使用 /v1/messages -// -// 旧账号 credentials 中若缺失该字段,按 sub2api 处理以保持向后兼容。 -const ( - AntigravityUpstreamTypeSub2Api = "sub2api" - AntigravityUpstreamTypeNewAPI = "newapi" -) - -// GetAntigravityUpstreamType 返回该账号的上游类型(仅对 Antigravity+APIKey 有意义)。 -func (a *Account) GetAntigravityUpstreamType() string { - t := strings.ToLower(strings.TrimSpace(a.GetCredential("upstream_type"))) - if t == AntigravityUpstreamTypeNewAPI { - return AntigravityUpstreamTypeNewAPI +// GetOpenAICompactMode returns the compact routing mode for an OpenAI account. +// Missing or invalid values fall back to "auto". +func (a *Account) GetOpenAICompactMode() string { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return OpenAICompactModeAuto } - return AntigravityUpstreamTypeSub2Api + mode, _ := a.Extra["openai_compact_mode"].(string) + return normalizeOpenAICompactMode(mode) +} + +// OpenAICompactSupportKnown reports whether compact capability is known for this +// account and, when known, whether it is supported. +func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) { + if a == nil || !a.IsOpenAI() { + return false, false + } + + switch a.GetOpenAICompactMode() { + case OpenAICompactModeForceOn: + return true, true + case OpenAICompactModeForceOff: + return false, true + } + + if a.Extra == nil { + return false, false + } + supported, ok := a.Extra["openai_compact_supported"].(bool) + if !ok { + return false, false + } + return supported, true +} + +// AllowsOpenAICompact reports whether the account may be considered for compact +// requests. Unknown capability remains allowed to avoid breaking older accounts +// before an explicit probe has been run. +func (a *Account) AllowsOpenAICompact() bool { + if a == nil || !a.IsOpenAI() { + return false + } + supported, known := a.OpenAICompactSupportKnown() + if !known { + return true + } + return supported +} + +// GetCompactModelMapping returns compact-only model remapping configuration. +// This mapping is intended for /responses/compact only and does not affect +// normal /responses traffic. +func (a *Account) GetCompactModelMapping() map[string]string { + if a == nil || a.Credentials == nil { + return nil + } + return stringMappingFromRaw(a.Credentials["compact_model_mapping"]) +} + +// ResolveCompactMappedModel resolves compact-only model remapping and reports +// whether a compact-specific mapping rule matched. +func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) { + mapping := a.GetCompactModelMapping() + if len(mapping) == 0 { + return requestedModel, false + } + if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched { + return mappedModel, true + } + return requestedModel, false } func (a *Account) GetBaseURL() string { @@ -646,25 +727,23 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } - if a.Platform == PlatformAntigravity && a.GetAntigravityUpstreamType() == AntigravityUpstreamTypeSub2Api { + if a.Platform == PlatformAntigravity { return strings.TrimRight(baseURL, "/") + "/antigravity" } - return strings.TrimRight(baseURL, "/") + return baseURL } // GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 -// Antigravity 平台的 APIKey 账号默认自动拼接 /antigravity; -// 若 upstream_type=newapi 则直接使用用户配置的 base_url。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { baseURL := strings.TrimSpace(a.GetCredential("base_url")) if baseURL == "" { return defaultBaseURL } - if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey && - a.GetAntigravityUpstreamType() == AntigravityUpstreamTypeSub2Api { + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { return strings.TrimRight(baseURL, "/") + "/antigravity" } - return strings.TrimRight(baseURL, "/") + return baseURL } func (a *Account) GetExtraString(key string) string { diff --git a/backend/internal/service/account_openai_compact_test.go b/backend/internal/service/account_openai_compact_test.go new file mode 100644 index 00000000..442b00da --- /dev/null +++ b/backend/internal/service/account_openai_compact_test.go @@ -0,0 +1,369 @@ +package service + +import "testing" + +func TestAccountGetOpenAICompactMode(t *testing.T) { + tests := []struct { + name string + account *Account + want string + }{ + { + name: "nil account defaults to auto", + want: OpenAICompactModeAuto, + }, + { + name: "non openai account defaults to auto", + account: &Account{ + Platform: PlatformAnthropic, + Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}, + }, + want: OpenAICompactModeAuto, + }, + { + name: "missing extra defaults to auto", + account: &Account{ + Platform: PlatformOpenAI, + }, + want: OpenAICompactModeAuto, + }, + { + name: "invalid mode falls back to auto", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_mode": " invalid "}, + }, + want: OpenAICompactModeAuto, + }, + { + name: "force on is normalized", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_mode": " FORCE_ON "}, + }, + want: OpenAICompactModeForceOn, + }, + { + name: "force off is normalized", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_mode": "force_off"}, + }, + want: OpenAICompactModeForceOff, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.account.GetOpenAICompactMode(); got != tt.want { + t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAccountOpenAICompactSupportKnown(t *testing.T) { + tests := []struct { + name string + account *Account + wantSupported bool + wantKnown bool + }{ + { + name: "nil account is unknown", + wantSupported: false, + wantKnown: false, + }, + { + name: "non openai account is unknown", + account: &Account{ + Platform: PlatformAnthropic, + Extra: map[string]any{"openai_compact_supported": true}, + }, + wantSupported: false, + wantKnown: false, + }, + { + name: "force on overrides probe state", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + "openai_compact_mode": OpenAICompactModeForceOn, + "openai_compact_supported": false, + }, + }, + wantSupported: true, + wantKnown: true, + }, + { + name: "force off overrides probe state", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + "openai_compact_mode": OpenAICompactModeForceOff, + "openai_compact_supported": true, + }, + }, + wantSupported: false, + wantKnown: true, + }, + { + name: "auto true is known supported", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_supported": true}, + }, + wantSupported: true, + wantKnown: true, + }, + { + name: "auto false is known unsupported", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_supported": false}, + }, + wantSupported: false, + wantKnown: true, + }, + { + name: "auto without probe state remains unknown", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{}, + }, + wantSupported: false, + wantKnown: false, + }, + { + name: "invalid probe field remains unknown", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_supported": "true"}, + }, + wantSupported: false, + wantKnown: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown() + if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown { + t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown) + } + }) + } +} + +func TestAccountAllowsOpenAICompact(t *testing.T) { + tests := []struct { + name string + account *Account + want bool + }{ + { + name: "nil account does not allow compact", + want: false, + }, + { + name: "non openai account does not allow compact", + account: &Account{ + Platform: PlatformAnthropic, + }, + want: false, + }, + { + name: "unknown openai account remains allowed", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{}, + }, + want: true, + }, + { + name: "supported openai account is allowed", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_supported": true}, + }, + want: true, + }, + { + name: "unsupported openai account is rejected", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_supported": false}, + }, + want: false, + }, + { + name: "force on is allowed", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}, + }, + want: true, + }, + { + name: "force off is rejected", + account: &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.account.AllowsOpenAICompact(); got != tt.want { + t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAccountGetCompactModelMapping(t *testing.T) { + tests := []struct { + name string + account *Account + want map[string]string + }{ + { + name: "nil account returns nil", + want: nil, + }, + { + name: "missing credentials returns nil", + account: &Account{ + Platform: PlatformOpenAI, + }, + want: nil, + }, + { + name: "map any is converted", + account: &Account{ + Credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4-openai-compact", + "invalid": 1, + }, + }, + }, + want: map[string]string{ + "gpt-5.4": "gpt-5.4-openai-compact", + }, + }, + { + name: "map string string is copied", + account: &Account{ + Credentials: map[string]any{ + "compact_model_mapping": map[string]string{ + "gpt-*": "compact-*", + }, + }, + }, + want: map[string]string{ + "gpt-*": "compact-*", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.account.GetCompactModelMapping() + if !equalStringMap(got, tt.want) { + t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want) + } + }) + } +} + +func TestAccountResolveCompactMappedModel(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + requestedModel string + expectedModel string + expectedMatch bool + }{ + { + name: "no compact mapping reports unmatched", + credentials: nil, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + { + name: "exact compact mapping matches", + credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4-openai-compact", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4-openai-compact", + expectedMatch: true, + }, + { + name: "exact passthrough counts as match", + credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "longest wildcard wins", + credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-*": "fallback-compact", + "gpt-5.4*": "gpt-5.4-openai-compact", + "gpt-5.4-mini*": "gpt-5.4-mini-openai-compact", + }, + }, + requestedModel: "gpt-5.4-mini", + expectedModel: "gpt-5.4-mini-openai-compact", + expectedMatch: true, + }, + { + name: "missing compact mapping reports unmatched", + credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.3": "gpt-5.3-openai-compact", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Credentials: tt.credentials, + } + gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel) + if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch { + t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch) + } + }) + } +} + +func equalStringMap(left, right map[string]string) bool { + if len(left) != len(right) { + return false + } + for key, want := range right { + if got, ok := left[key]; !ok || got != want { + return false + } + } + return true +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 4e51d95f..ce2a5dbe 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -174,7 +174,8 @@ func createTestPayload(modelID string, prompt string) (map[string]any, error) { // TestAccountConnection tests an account's connection by sending a test request // All account types use full Claude Code client characteristics, only auth header differs // modelID is optional - if empty, defaults to claude.DefaultTestModel -func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error { +// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path +func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error { ctx := c.Request.Context() // Get account @@ -185,7 +186,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int // Route to platform-specific test method if account.IsOpenAI() { - return s.testOpenAIAccountConnection(c, account, modelID, prompt) + return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode)) } if account.IsGemini() { @@ -433,9 +434,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co } // testOpenAIAccountConnection tests an OpenAI account's connection -func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { +func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error { ctx := c.Request.Context() _ = prompt + mode = normalizeAccountTestMode(mode) // Default to openai.DefaultTestModel for OpenAI testing testModelID := modelID @@ -443,14 +445,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account testModelID = openai.DefaultTestModel } - // For API Key accounts with model mapping, map the model - if account.Type == "apikey" { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } + // Align test routing with gateway behavior: OpenAI accounts apply normal + // account model mapping, and compact mode applies compact-only mapping on top. + testModelID = account.GetMappedModel(testModelID) + if mode == AccountTestModeCompact { + testModelID = resolveOpenAICompactForwardModel(account, testModelID) + return s.testOpenAICompactConnection(c, account, testModelID) } // Route to image generation test if an image model is selected @@ -555,6 +555,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusTooManyRequests { + s.reconcileOpenAI429State(ctx, account, resp.Header, body) + } // 401 Unauthorized: 标记账号为永久错误 if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body)) @@ -567,6 +570,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account return s.processOpenAIStream(c, resp.Body) } +// testOpenAICompactConnection probes /responses/compact and persists the +// resulting capability state on the account. +func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error { + ctx := c.Request.Context() + + authToken := "" + apiURL := "" + isOAuth := false + chatgptAccountID := "" + + switch { + case account.IsOAuth(): + isOAuth = true + authToken = account.GetOpenAIAccessToken() + if authToken == "" { + return s.sendErrorAndEnd(c, "No access token available") + } + apiURL = chatgptCodexAPIURL + "/compact" + chatgptAccountID = account.GetChatGPTAccountID() + case account.Type == AccountTypeAPIKey: + authToken = account.GetOpenAIApiKey() + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) + } + apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact") + default: + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID)) + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("Originator", "codex_cli_rs") + req.Header.Set("User-Agent", codexCLIUserAgent) + req.Header.Set("Version", codexCLIVersion) + probeSessionID := compactProbeSessionID(account.ID) + req.Header.Set("Session_ID", probeSessionID) + req.Header.Set("Conversation_ID", probeSessionID) + + if isOAuth { + req.Host = "chatgpt.com" + if chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + if s.accountRepo != nil { + updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now()) + _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) + mergeAccountExtra(account, updates) + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + if s.accountRepo != nil { + updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now()) + if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 { + updates = mergeExtraUpdates(updates, codexUpdates) + } + if len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) + mergeAccountExtra(account, updates) + } + // 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。 + if resp.StatusCode == http.StatusTooManyRequests { + s.reconcileOpenAI429State(ctx, account, resp.Header, body) + } + } + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { + errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body)) + _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + } + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) { + if s == nil || s.accountRepo == nil || account == nil { + return + } + + var resetAt *time.Time + if calculated := calculateOpenAI429ResetTime(headers); calculated != nil { + resetAt = calculated + } else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil { + t := time.Unix(*unixTs, 0) + resetAt = &t + } + if resetAt == nil { + return + } + + if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { + return + } + + now := time.Now() + account.RateLimitedAt = &now + account.RateLimitResetAt = resetAt + + if account.Status == StatusError { + if err := s.accountRepo.ClearError(ctx, account.ID); err != nil { + return + } + account.Status = StatusActive + account.ErrorMessage = "" + } +} + // testGeminiAccountConnection tests a Gemini account's connection func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() @@ -1053,13 +1204,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) // processOpenAIStream processes the SSE stream from OpenAI Responses API func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error { reader := bufio.NewReader(body) + seenCompleted := false for { line, err := reader.ReadString('\n') if err != nil { if err == io.EOF { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil + if seenCompleted { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + return s.sendErrorAndEnd(c, "Stream ended before response.completed") } return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) } @@ -1071,8 +1226,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) jsonStr := sseDataPrefix.ReplaceAllString(line, "") if jsonStr == "[DONE]" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil + if seenCompleted { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + return s.sendErrorAndEnd(c, "Stream ended before response.completed") } var data map[string]any @@ -1088,9 +1246,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) if delta, ok := data["delta"].(string); ok && delta != "" { s.sendEvent(c, TestEvent{Type: "content", Text: delta}) } - case "response.completed": + case "response.completed", "response.done": s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil + case "response.failed": + errorMsg := "OpenAI response failed" + if responseData, ok := data["response"].(map[string]any); ok { + if errData, ok := responseData["error"].(map[string]any); ok { + if msg, ok := errData["message"].(string); ok && msg != "" { + errorMsg = msg + } + } + } + return s.sendErrorAndEnd(c, errorMsg) case "error": errorMsg := "Unknown error" if errData, ok := data["error"].(map[string]any); ok { @@ -1320,7 +1488,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in ginCtx, _ := gin.CreateTestContext(w) ginCtx.Request = (&http.Request{}).WithContext(ctx) - testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "") + testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault) finishedAt := time.Now() body := w.Body.String() diff --git a/backend/internal/service/account_test_service_openai_compact_test.go b/backend/internal/service/account_test_service_openai_compact_test.go new file mode 100644 index 00000000..9eb98fdc --- /dev/null +++ b/backend/internal/service/account_test_service_openai_compact_test.go @@ -0,0 +1,199 @@ +package service + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) { + gin.SetMode(gin.TestMode) + + updateCalls := make(chan map[string]any, 1) + account := Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + repo := &snapshotUpdateAccountRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + updateExtraCalls: updateCalls, + } + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)), + }} + svc := &AccountTestService{ + accountRepo: repo, + httpUpstream: upstream, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil)) + + err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact) + require.NoError(t, err) + + require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String()) + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version")) + require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id")) + require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent")) + require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) + require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + + updates := <-updateCalls + require.Equal(t, true, updates["openai_compact_supported"]) + require.Equal(t, http.StatusOK, updates["openai_compact_last_status"]) + require.Contains(t, rec.Body.String(), `"type":"test_complete"`) +} + +func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) { + gin.SetMode(gin.TestMode) + + updateCalls := make(chan map[string]any, 1) + account := Account{ + ID: 2, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + repo := &snapshotUpdateAccountRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + updateExtraCalls: updateCalls, + } + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusNotFound, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`404 page not found`)), + }} + svc := &AccountTestService{ + accountRepo: repo, + httpUpstream: upstream, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil)) + + err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact) + require.Error(t, err) + + updates := <-updateCalls + require.Equal(t, false, updates["openai_compact_supported"]) + require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"]) + require.Contains(t, rec.Body.String(), `"type":"error"`) +} + +func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + updateCalls := make(chan map[string]any, 1) + account := Account{ + ID: 3, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://example.com/v1", + "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"}, + }, + } + repo := &snapshotUpdateAccountRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + updateExtraCalls: updateCalls, + } + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)), + }} + svc := &AccountTestService{ + accountRepo: repo, + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil)) + + err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact) + require.NoError(t, err) + + require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String()) + require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String()) + updates := <-updateCalls + require.Equal(t, true, updates["openai_compact_supported"]) +} + +func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) { + gin.SetMode(gin.TestMode) + + updateCalls := make(chan map[string]any, 1) + account := Account{ + ID: 4, + Name: "openai-apikey-default", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + } + repo := &snapshotUpdateAccountRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + updateExtraCalls: updateCalls, + } + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)), + }} + svc := &AccountTestService{ + accountRepo: repo, + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil)) + + err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact) + require.NoError(t, err) + require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String()) + <-updateCalls +} diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 82ff0a8b..56204be3 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) { type openAIAccountTestRepo struct { mockAccountRepoForGemini - updatedExtra map[string]any - rateLimitedID int64 - rateLimitedAt *time.Time + updatedExtra map[string]any + rateLimitedID int64 + rateLimitedAt *time.Time + clearedErrorID int64 + setErrorID int64 + setErrorMsg string } func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { @@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese return nil } +func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error { + r.clearedErrorID = id + return nil +} + +func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error { + r.setErrorID = id + r.setErrorMsg = errorMsg + return nil +} + func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { gin.SetMode(gin.TestMode) ctx, recorder := newTestContext() @@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") require.NoError(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) @@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. require.Contains(t, recorder.Body.String(), "test_complete") } -func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) { +func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newTestContext() + + resp := newJSONResponse(http.StatusOK, "") + resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"} + +`)) + + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 90, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Contains(t, recorder.Body.String(), "response.completed") + require.NotContains(t, recorder.Body.String(), `"success":true`) +} + +func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) { gin.SetMode(gin.TestMode) ctx, _ := newTestContext() - resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`) resp.Header.Set("x-codex-primary-used-percent", "100") resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") resp.Header.Set("x-codex-primary-window-minutes", "10080") @@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing ID: 88, Platform: PlatformOpenAI, Type: AccountTypeOAuth, + Status: StatusError, Concurrency: 1, Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") require.Error(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, account.ID, repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.Equal(t, account.ID, repo.clearedErrorID) + require.Equal(t, StatusActive, account.Status) + require.Empty(t, account.ErrorMessage) + require.NotNil(t, account.RateLimitResetAt) +} + +func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 77, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions", + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Equal(t, account.ID, repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.Equal(t, account.ID, repo.clearedErrorID) + require.Equal(t, StatusActive, account.Status) + require.Empty(t, account.ErrorMessage) + require.NotNil(t, account.RateLimitResetAt) + require.Empty(t, repo.updatedExtra) +} + +func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 78, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Equal(t, account.ID, repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.Zero(t, repo.clearedErrorID) + require.Equal(t, StatusActive, account.Status) + require.NotNil(t, account.RateLimitResetAt) +} + +func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 79, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "stale 403", + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) require.Zero(t, repo.rateLimitedID) require.Nil(t, repo.rateLimitedAt) + require.Zero(t, repo.clearedErrorID) + require.Equal(t, StatusError, account.Status) + require.Equal(t, "stale 403", account.ErrorMessage) + require.Nil(t, account.RateLimitResetAt) +} + +func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 80, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Equal(t, account.ID, repo.setErrorID) + require.Contains(t, repo.setErrorMsg, "Authentication failed (401)") + require.Zero(t, repo.rateLimitedID) + require.Zero(t, repo.clearedErrorID) require.Nil(t, account.RateLimitResetAt) } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index d1775fd5..b8fc1d4c 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -110,7 +110,7 @@ const ( apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 windowStatsCacheTTL = 1 * time.Minute openAIProbeCacheTTL = 10 * time.Minute - openAICodexProbeVersion = "0.104.0" + openAICodexProbeVersion = "0.125.0" ) // UsageCache 封装账户使用量相关的缓存 diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go new file mode 100644 index 00000000..aca32076 --- /dev/null +++ b/backend/internal/service/affiliate_service.go @@ -0,0 +1,448 @@ +package service + +import ( + "context" + "errors" + "math" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +var ( + ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found") + ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code") + ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use") + ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound") + ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer") +) + +const ( + affiliateInviteesLimit = 100 + // AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated + // 12-char codes and admin-customized codes (e.g. "VIP2026"). + AffiliateCodeMinLength = 4 + AffiliateCodeMaxLength = 32 +) + +// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash. +// All input passes through strings.ToUpper before validation, so lowercase from +// users is normalized — admins may supply mixed case in their UI. +var affiliateCodeValidChar = func() [256]bool { + var tbl [256]bool + for c := byte('A'); c <= 'Z'; c++ { + tbl[c] = true + } + for c := byte('0'); c <= '9'; c++ { + tbl[c] = true + } + tbl['_'] = true + tbl['-'] = true + return tbl +}() + +// isValidAffiliateCodeFormat validates code format for both binding (user input) +// and admin updates. Caller is expected to upper-case the input first. +func isValidAffiliateCodeFormat(code string) bool { + if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength { + return false + } + for i := 0; i < len(code); i++ { + if !affiliateCodeValidChar[code[i]] { + return false + } + } + return true +} + +type AffiliateSummary struct { + UserID int64 `json:"user_id"` + AffCode string `json:"aff_code"` + AffCodeCustom bool `json:"aff_code_custom"` + AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"` + InviterID *int64 `json:"inviter_id,omitempty"` + AffCount int `json:"aff_count"` + AffQuota float64 `json:"aff_quota"` + AffHistoryQuota float64 `json:"aff_history_quota"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type AffiliateInvitee struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + CreatedAt *time.Time `json:"created_at,omitempty"` +} + +type AffiliateDetail struct { + UserID int64 `json:"user_id"` + AffCode string `json:"aff_code"` + InviterID *int64 `json:"inviter_id,omitempty"` + AffCount int `json:"aff_count"` + AffQuota float64 `json:"aff_quota"` + AffHistoryQuota float64 `json:"aff_history_quota"` + // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例: + // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。 + // 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。 + EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"` + Invitees []AffiliateInvitee `json:"invitees"` +} + +type AffiliateRepository interface { + EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) + GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) + BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) + AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) + TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) + ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error) + + // 管理端:用户级专属配置 + UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error + ResetUserAffCode(ctx context.Context, userID int64) (string, error) + SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error + BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error + ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) +} + +// AffiliateAdminFilter 列表筛选条件 +type AffiliateAdminFilter struct { + Search string + Page int + PageSize int +} + +// AffiliateAdminEntry 专属用户列表条目 +type AffiliateAdminEntry struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + AffCode string `json:"aff_code"` + AffCodeCustom bool `json:"aff_code_custom"` + AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"` + AffCount int `json:"aff_count"` +} + +type AffiliateService struct { + repo AffiliateRepository + settingService *SettingService + authCacheInvalidator APIKeyAuthCacheInvalidator + billingCacheService *BillingCacheService +} + +func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService { + return &AffiliateService{ + repo: repo, + settingService: settingService, + authCacheInvalidator: authCacheInvalidator, + billingCacheService: billingCacheService, + } +} + +// IsEnabled reports whether the affiliate (邀请返利) feature is turned on. +func (s *AffiliateService) IsEnabled(ctx context.Context) bool { + if s == nil || s.settingService == nil { + return AffiliateEnabledDefault + } + return s.settingService.IsAffiliateEnabled(ctx) +} + +func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER", "invalid user") + } + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.EnsureUserAffiliate(ctx, userID) +} + +func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) { + summary, err := s.EnsureUserAffiliate(ctx, userID) + if err != nil { + return nil, err + } + invitees, err := s.listInvitees(ctx, userID) + if err != nil { + return nil, err + } + return &AffiliateDetail{ + UserID: summary.UserID, + AffCode: summary.AffCode, + InviterID: summary.InviterID, + AffCount: summary.AffCount, + AffQuota: summary.AffQuota, + AffHistoryQuota: summary.AffHistoryQuota, + EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary), + Invitees: invitees, + }, nil +} + +func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error { + code := strings.ToUpper(strings.TrimSpace(rawCode)) + if code == "" { + return nil + } + if s == nil || s.repo == nil { + return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + // 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程) + if !s.IsEnabled(ctx) { + return nil + } + if !isValidAffiliateCodeFormat(code) { + return ErrAffiliateCodeInvalid + } + + selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID) + if err != nil { + return err + } + if selfSummary.InviterID != nil { + return nil + } + + inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code) + if err != nil { + if errors.Is(err, ErrAffiliateProfileNotFound) { + return ErrAffiliateCodeInvalid + } + return err + } + if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID { + return ErrAffiliateCodeInvalid + } + + bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID) + if err != nil { + return err + } + if !bound { + return ErrAffiliateAlreadyBound + } + return nil +} + +func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) { + if s == nil || s.repo == nil { + return 0, nil + } + if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) { + return 0, nil + } + // 总开关关闭时,新充值不再产生返利 + if !s.IsEnabled(ctx) { + return 0, nil + } + + inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID) + if err != nil { + return 0, err + } + if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 { + return 0, nil + } + + // 加载邀请人 profile,优先使用专属比例(覆盖全局) + inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID) + if err != nil { + return 0, err + } + rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary) + rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8) + if rebate <= 0 { + return 0, nil + } + + applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate) + if err != nil { + return 0, err + } + if !applied { + return 0, nil + } + return rebate, nil +} + +// resolveRebateRatePercent returns the inviter's exclusive rate when set, +// otherwise the global setting value (clamped to [Min, Max]). +func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 { + if inviter != nil && inviter.AffRebateRatePercent != nil { + v := *inviter.AffRebateRatePercent + if math.IsNaN(v) || math.IsInf(v, 0) { + return s.globalRebateRatePercent(ctx) + } + return clampAffiliateRebateRate(v) + } + return s.globalRebateRatePercent(ctx) +} + +// globalRebateRatePercent reads the system-wide rebate rate via SettingService, +// returning the documented default when SettingService is unavailable. +func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 { + if s == nil || s.settingService == nil { + return AffiliateRebateRateDefault + } + return s.settingService.GetAffiliateRebateRatePercent(ctx) +} + +func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) { + if s == nil || s.repo == nil { + return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + + transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID) + if err != nil { + return 0, 0, err + } + if transferred > 0 { + s.invalidateAffiliateCaches(ctx, userID) + } + return transferred, balance, nil +} + +func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) { + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit) + if err != nil { + return nil, err + } + for i := range invitees { + invitees[i].Email = maskEmail(invitees[i].Email) + } + return invitees, nil +} + +func roundTo(v float64, scale int) float64 { + factor := math.Pow10(scale) + return math.Round(v*factor) / factor +} + +func maskEmail(email string) string { + email = strings.TrimSpace(email) + if email == "" { + return "" + } + at := strings.Index(email, "@") + if at <= 0 || at >= len(email)-1 { + return "***" + } + + local := email[:at] + domain := email[at+1:] + dot := strings.LastIndex(domain, ".") + + maskedLocal := maskSegment(local) + if dot <= 0 || dot >= len(domain)-1 { + return maskedLocal + "@" + maskSegment(domain) + } + + domainName := domain[:dot] + tld := domain[dot:] + return maskedLocal + "@" + maskSegment(domainName) + tld +} + +func maskSegment(s string) string { + r := []rune(s) + if len(r) == 0 { + return "***" + } + if len(r) == 1 { + return string(r[0]) + "***" + } + return string(r[0]) + "***" +} + +func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) { + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService != nil { + if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil { + logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err) + } + } +} + +// ========================= +// Admin: 专属配置管理 +// ========================= + +// validateExclusiveRate ensures a per-user override is finite and within +// [Min, Max]. nil is always valid (means "clear / fall back to global"). +func validateExclusiveRate(ratePercent *float64) error { + if ratePercent == nil { + return nil + } + v := *ratePercent + if math.IsNaN(v) || math.IsInf(v, 0) { + return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate") + } + if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax { + return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range") + } + return nil +} + +// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。 +func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error { + if s == nil || s.repo == nil { + return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + code := strings.ToUpper(strings.TrimSpace(rawCode)) + if !isValidAffiliateCodeFormat(code) { + return ErrAffiliateCodeInvalid + } + return s.repo.UpdateUserAffCode(ctx, userID, code) +} + +// AdminResetUserAffCode 重置用户邀请码为系统随机码。 +func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) { + if s == nil || s.repo == nil { + return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ResetUserAffCode(ctx, userID) +} + +// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。 +func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error { + if s == nil || s.repo == nil { + return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + if err := validateExclusiveRate(ratePercent); err != nil { + return err + } + return s.repo.SetUserRebateRate(ctx, userID, ratePercent) +} + +// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。 +func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error { + if s == nil || s.repo == nil { + return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + if err := validateExclusiveRate(ratePercent); err != nil { + return err + } + cleaned := make([]int64, 0, len(userIDs)) + for _, uid := range userIDs { + if uid > 0 { + cleaned = append(cleaned, uid) + } + } + if len(cleaned) == 0 { + return nil + } + return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent) +} + +// AdminListCustomUsers 列出有专属配置的用户。 +func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListUsersWithCustomSettings(ctx, filter) +} diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go new file mode 100644 index 00000000..c02a4dd7 --- /dev/null +++ b/backend/internal/service/affiliate_service_test.go @@ -0,0 +1,131 @@ +//go:build unit + +package service + +import ( + "context" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter +// AffRebateRatePercent overrides the global rate, that NULL falls back to the +// global rate, and that out-of-range exclusive rates are clamped silently. +// +// SettingService is left nil here so globalRebateRatePercent returns the +// documented default (AffiliateRebateRateDefault = 20%) — this exercises the +// fallback path without spinning up a settings stub. +func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) { + t.Parallel() + svc := &AffiliateService{} + + // nil exclusive rate → falls back to global default (20%) + require.InDelta(t, AffiliateRebateRateDefault, + svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9) + + // exclusive rate set → overrides global + rate := 50.0 + require.InDelta(t, 50.0, + svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9) + + // exclusive rate 0 → returns 0 (no rebate, intentional) + zero := 0.0 + require.InDelta(t, 0.0, + svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9) + + // exclusive rate above max → clamped to Max + tooHigh := 250.0 + require.InDelta(t, AffiliateRebateRateMax, + svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9) + + // exclusive rate below min → clamped to Min + tooLow := -5.0 + require.InDelta(t, AffiliateRebateRateMin, + svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9) +} + +// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled +// safely handles a nil settingService dependency by returning the default +// (off). This protects callers from nil-pointer crashes in misconfigured +// environments. +func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) { + t.Parallel() + svc := &AffiliateService{} + require.False(t, svc.IsEnabled(context.Background())) + require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background())) +} + +// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by +// admin-facing rate setters: nil is always valid (clear), in-range values +// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest. +func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) { + t.Parallel() + require.NoError(t, validateExclusiveRate(nil)) + + for _, v := range []float64{0, 0.01, 50, 99.99, 100} { + v := v + require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v) + } + + for _, v := range []float64{-0.01, 100.01, -100, 200} { + v := v + require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v) + } + + nan := math.NaN() + require.Error(t, validateExclusiveRate(&nan)) + posInf := math.Inf(1) + require.Error(t, validateExclusiveRate(&posInf)) + negInf := math.Inf(-1) + require.Error(t, validateExclusiveRate(&negInf)) +} + +func TestMaskEmail(t *testing.T) { + t.Parallel() + require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com")) + require.Equal(t, "x***@d***", maskEmail("x@domain")) + require.Equal(t, "", maskEmail("")) +} + +func TestIsValidAffiliateCodeFormat(t *testing.T) { + t.Parallel() + + // 邀请码格式校验同时服务于: + // 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1) + // 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1") + // 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。 + cases := []struct { + name string + in string + want bool + }{ + {"valid canonical 12-char", "ABCDEFGHJKLM", true}, + {"valid all digits 2-9", "234567892345", true}, + {"valid mixed", "A2B3C4D5E6F7", true}, + {"valid admin custom short", "VIP1", true}, + {"valid admin custom with hyphen", "NEW-USER", true}, + {"valid admin custom with underscore", "VIP_2026", true}, + {"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true}, + // Previously-excluded chars (I/O/0/1) are now allowed since admins may use them. + {"letter I now allowed", "IBCDEFGHJKLM", true}, + {"letter O now allowed", "OBCDEFGHJKLM", true}, + {"digit 0 now allowed", "0BCDEFGHJKLM", true}, + {"digit 1 now allowed", "1BCDEFGHJKLM", true}, + {"too short (3 chars)", "ABC", false}, + {"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false}, + {"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false}, + {"empty", "", false}, + {"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset + {"ascii punctuation .", "ABCDEFGHJK.M", false}, + {"whitespace", "ABCDEFGHJK M", false}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in)) + }) + } +} diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index e3fb2f85..21d9d6e9 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService( nil, nil, nil, + nil, ) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index e45d8d66..08b0f4b7 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -72,6 +72,7 @@ type AuthService struct { turnstileService *TurnstileService emailQueueService *EmailQueueService promoService *PromoService + affiliateService *AffiliateService defaultSubAssigner DefaultSubscriptionAssigner } @@ -98,6 +99,7 @@ func NewAuthService( emailQueueService *EmailQueueService, promoService *PromoService, defaultSubAssigner DefaultSubscriptionAssigner, + affiliateService *AffiliateService, ) *AuthService { return &AuthService{ entClient: entClient, @@ -110,6 +112,7 @@ func NewAuthService( turnstileService: turnstileService, emailQueueService: emailQueueService, promoService: promoService, + affiliateService: affiliateService, defaultSubAssigner: defaultSubAssigner, } } @@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client { // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { - return s.RegisterWithVerification(ctx, email, password, "", "", "") + return s.RegisterWithVerification(ctx, email, password, "", "", "", "") } -// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户 -func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) { +// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。 +func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) { // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled @@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } s.postAuthUserBootstrap(ctx, user, "email", true) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + if s.affiliateService != nil { + if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err) + } + if code := strings.TrimSpace(affiliateCode); code != "" { + if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil { + // 邀请返利码绑定失败不影响注册,只记录日志 + logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err) + } + } + } // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index 226eb8e8..e2392e4b 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( emailSvc = service.NewEmailService(settingRepo, emailCache) } - svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) + svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil) return svc, repo, client } @@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t }, } emailService := service.NewEmailService(nil, cache) - svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil) + svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil) oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{ ID: 41, diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 2233e427..53048b92 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( values: settings, }, cfg) - svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner) + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil) return svc, repo, client } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index dbd18a20..c1ad6240 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, nil, // promoService nil, // defaultSubAssigner + nil, // affiliateService ) } @@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi }, nil) // 应返回服务不可用错误,而不是允许绕过验证 - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "") require.ErrorIs(t, err, ErrServiceUnavailable) } @@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { SettingKeyEmailVerifyEnabled: "true", }, cache) - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { SettingKeyEmailVerifyEnabled: "true", }, cache) - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "") require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorContains(t, err, "verify code") } diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 477ba1b2..3512822f 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier nil, // emailQueueService nil, // promoService nil, // defaultSubAssigner + nil, // affiliateService ) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 5e9279a8..1c8e7cc9 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -18,6 +18,14 @@ const ( RoleUser = domain.RoleUser ) +// Affiliate rebate settings +const ( + AffiliateRebateRateDefault = 20.0 + AffiliateRebateRateMin = 0.0 + AffiliateRebateRateMax = 100.0 + AffiliateEnabledDefault = false // 邀请返利总开关默认关闭 +) + // Platform constants const ( PlatformAnthropic = domain.PlatformAnthropic @@ -88,6 +96,8 @@ const ( SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接 SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 + SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关 + SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 5be1f733..428231ee 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -762,8 +762,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock( system := gjson.GetBytes(upstream.lastBody, "system") require.True(t, system.Exists()) require.True(t, system.IsArray(), "system should be an array") - require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String()) - require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String()) + arr := system.Array() + require.Len(t, arr, 2, "system array should have billing block + cc prompt block") + + require.Contains(t, arr[0].Get("text").String(), "x-anthropic-billing-header:") + require.Contains(t, arr[0].Get("text").String(), "cc_version=") + + require.Equal(t, claudeCodeSystemPrompt, arr[1].Get("text").String()) + require.Equal(t, "ephemeral", arr[1].Get("cache_control.type").String()) // 原始 system prompt 应迁移至 messages 中 messages := gjson.GetBytes(upstream.lastBody, "messages") diff --git a/backend/internal/service/gateway_billing_block.go b/backend/internal/service/gateway_billing_block.go new file mode 100644 index 00000000..45c307fd --- /dev/null +++ b/backend/internal/service/gateway_billing_block.go @@ -0,0 +1,98 @@ +package service + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + + "github.com/tidwall/gjson" +) + +// fingerprintSalt 是计算 cc_version 后缀指纹的盐值。 +// +// 来源:与 Parrot src/transform/cc_mimicry.py 的 FINGERPRINT_SALT 完全一致; +// 这是真实 Claude Code CLI 抓包推导出的常量,改动会导致 fp 与 CLI 不一致, +// 进一步触发 Anthropic 的第三方检测。 +const fingerprintSalt = "59cf53e54c78" + +// computeClaudeCodeFingerprint 复刻真实 Claude Code CLI 的 cc_version 指纹算法: +// +// 1. 取 messages 中第一条 role=user 的纯文本(首块 text) +// 2. 取该文本的第 4、7、20 字符(不足以 '0' 补齐) +// 3. SHA256(SALT + chars + cc_version) 取 hex 前 3 字符 +// +// 算法来自 Parrot src/transform/cc_mimicry.py:compute_fingerprint,与官方 CLI 字节对齐。 +// 任何偏差都会导致 cc_version=X.Y.Z.{fp} 在上游侧与真实 CLI 不一致。 +func computeClaudeCodeFingerprint(body []byte, version string) string { + firstText := extractFirstUserText(body) + indices := []int{4, 7, 20} + chars := make([]byte, 0, 3) + for _, i := range indices { + if i < len(firstText) { + chars = append(chars, firstText[i]) + } else { + chars = append(chars, '0') + } + } + sum := sha256.Sum256([]byte(fingerprintSalt + string(chars) + version)) + return hex.EncodeToString(sum[:])[:3] +} + +// extractFirstUserText 提取 messages 中第一条 user 消息的首段 text 内容。 +// 兼容 string 和 []block 两种 content 格式。 +func extractFirstUserText(body []byte) string { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() { + return "" + } + first := "" + messages.ForEach(func(_, msg gjson.Result) bool { + if msg.Get("role").String() != "user" { + return true + } + content := msg.Get("content") + if content.Type == gjson.String { + first = content.String() + return false + } + if content.IsArray() { + content.ForEach(func(_, block gjson.Result) bool { + if block.Get("type").String() == "text" { + first = block.Get("text").String() + return false + } + return true + }) + return false + } + return false + }) + return first +} + +// buildBillingAttributionBlockJSON 构造 system 数组的 billing attribution block。 +// +// 形态严格对齐真实 Claude Code CLI: +// +// {"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92.{fp}; cc_entrypoint=cli; cch=00000;"} +// +// cch=00000 是签名占位符,由 signBillingHeaderCCH 在 buildUpstreamRequest 阶段 +// 替换为基于完整 body 的 xxhash64 5 位十六进制摘要。 +// +// 此 block 不带 cache_control(与真实 CLI 一致;cache breakpoint 由后续的 +// Claude Code prompt block 承担)。 +func buildBillingAttributionBlockJSON(body []byte, cliVersion string) ([]byte, error) { + if cliVersion == "" { + return nil, fmt.Errorf("cliVersion required") + } + fp := computeClaudeCodeFingerprint(body, cliVersion) + text := fmt.Sprintf( + "x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=cli; cch=00000;", + cliVersion, fp, + ) + return json.Marshal(map[string]string{ + "type": "text", + "text": text, + }) +} diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go index cfaf0a6a..e6c9de7d 100644 --- a/backend/internal/service/gateway_body_order_test.go +++ b/backend/internal/service/gateway_body_order_test.go @@ -41,13 +41,13 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing. resultStr := string(result) require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID) - assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"tool_choice"`, `"omega"`, `"tools"`, `"metadata"`) - // temperature 和 tool_choice 不再剥离,透传客户端原始值(与真实 CLI 行为一致) - require.Contains(t, resultStr, `"temperature"`) - require.Contains(t, resultStr, `"tool_choice"`) + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`, `"max_tokens"`) + require.Contains(t, resultStr, `"temperature":0.2`) + require.NotContains(t, resultStr, `"tool_choice"`) require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`) require.Contains(t, resultStr, `"tools":[]`) require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`) + require.Contains(t, resultStr, `"max_tokens":128000`) } func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) { diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index 37b38f76..c531667e 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -85,15 +85,16 @@ func (s *GatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("marshal anthropic request: %w", err) } - // 6. Apply Claude Code mimicry for OAuth accounts - isClaudeCode := false // CC API is never Claude Code + // 6. Apply Claude Code mimicry for OAuth accounts. + // Chat Completions 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号 + // 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入), + // 否则会被 Anthropic 判为第三方应用并扣 extra usage。 + // 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。 + isClaudeCode := false shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - if !strings.Contains(strings.ToLower(mappedModel), "haiku") && - !systemIncludesClaudeCodePrompt(anthropicReq.System) { - anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System) - } + anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel) } // 7. Enforce cache_control block limit @@ -312,7 +313,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic( if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } - c.JSON(http.StatusOK, ccResp) + // Marshal then bytes-replace so tool name mapping is reversed at byte level + // (parity with Parrot non-stream flow that marshals → restore → emit). + if respBytes, err := json.Marshal(ccResp); err == nil { + respBytes = reverseToolNamesIfPresent(c, respBytes) + c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes) + } else { + c.JSON(http.StatusOK, ccResp) + } return &ForwardResult{ RequestID: requestID, @@ -383,7 +391,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic( if err != nil { return false } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { + // Reverse tool name mapping: fake → real, per-chunk bytes.Replace. + // c 可能持有请求侧注入的 ToolNameRewrite;无则仅做静态前缀还原。 + out := string(reverseToolNamesIfPresent(c, []byte(sse))) + if _, err := fmt.Fprint(c.Writer, out); err != nil { return true // client disconnected } return false diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index 2c917112..647193d6 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -82,15 +82,16 @@ func (s *GatewayService) ForwardAsResponses( return nil, fmt.Errorf("marshal anthropic request: %w", err) } - // 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints) - isClaudeCode := false // Responses API is never Claude Code + // 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints). + // OpenAI Responses 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号 + // 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入), + // 否则会被 Anthropic 判为第三方应用并扣 extra usage。 + // 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。 + isClaudeCode := false shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - if !strings.Contains(strings.ToLower(mappedModel), "haiku") && - !systemIncludesClaudeCodePrompt(anthropicReq.System) { - anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System) - } + anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel) } // 7. Enforce cache_control block limit @@ -331,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse( if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } - c.JSON(http.StatusOK, responsesResp) + if respBytes, err := json.Marshal(responsesResp); err == nil { + respBytes = reverseToolNamesIfPresent(c, respBytes) + c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes) + } else { + c.JSON(http.StatusOK, responsesResp) + } return &ForwardResult{ RequestID: requestID, @@ -419,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse( ) continue } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { + out := string(reverseToolNamesIfPresent(c, []byte(sse))) + if _, err := fmt.Fprint(c.Writer, out); err != nil { logger.L().Info("forward_as_responses stream: client disconnected", zap.String("request_id", requestID), ) @@ -439,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse( if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + out := string(reverseToolNamesIfPresent(c, []byte(sse))) + fmt.Fprint(c.Writer, out) //nolint:errcheck } c.Writer.Flush() } diff --git a/backend/internal/service/gateway_messages_cache.go b/backend/internal/service/gateway_messages_cache.go new file mode 100644 index 00000000..cb5384ba --- /dev/null +++ b/backend/internal/service/gateway_messages_cache.go @@ -0,0 +1,141 @@ +package service + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。 +// 与 Parrot _strip_message_cache_control 语义一致。 +// +// 为什么必须整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在 +// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条 +// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。 +// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。 +func stripMessageCacheControl(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() { + return body + } + msgIdx := -1 + messages.ForEach(func(_, msg gjson.Result) bool { + msgIdx++ + content := msg.Get("content") + if !content.IsArray() { + return true + } + blockIdx := -1 + content.ForEach(func(_, block gjson.Result) bool { + blockIdx++ + if !block.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx) + if next, err := sjson.DeleteBytes(body, path); err == nil { + body = next + } + return true + }) + return true + }) + return body +} + +// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点: +// 1. 最后一条 message +// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message +// +// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点 +// + tools[-1] 的断点共同构成最多 4 个断点(Anthropic 上限)。 +// +// cache_control ttl 策略: +// - 若目标 block 已有 cache_control.ttl → 不覆盖 +// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL} +// +// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。 +func addMessageCacheBreakpoints(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() { + return body + } + arr := messages.Array() + if len(arr) == 0 { + return body + } + + body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1]) + + if len(arr) >= 4 { + userCount := 0 + for i := len(arr) - 1; i >= 0; i-- { + if arr[i].Get("role").String() != "user" { + continue + } + userCount++ + if userCount == 2 { + body = injectCacheControlOnLastContentBlock(body, i, &arr[i]) + break + } + } + } + + return body +} + +// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx] +// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组 +// (对齐 Parrot _inject_cache_on_msg 的行为)。 +// +// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。 +func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte { + content := msg.Get("content") + + if content.Type == gjson.String { + text := content.String() + blockRaw := fmt.Sprintf( + `[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`, + mustJSONString(text), claude.DefaultCacheControlTTL, + ) + if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil { + body = next + } + return body + } + + if !content.IsArray() { + return body + } + contentArr := content.Array() + if len(contentArr) == 0 { + return body + } + lastBlockIdx := len(contentArr) - 1 + lastBlock := contentArr[lastBlockIdx] + + if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" { + return body + } + + pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx) + existingCC := lastBlock.Get("cache_control") + if existingCC.Exists() { + if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil { + body = next + } + return body + } + raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL) + if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil { + body = next + } + return body +} + +// mustJSONString 把一个 Go string 序列化为合法 JSON string(含引号), +// 用于 sjson.SetRawBytes 场景下手工拼 JSON。 +func mustJSONString(s string) string { + return fmt.Sprintf("%q", s) +} diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index e27e18aa..443486ab 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -378,16 +378,27 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { err := json.Unmarshal(result, &parsed) require.NoError(t, err) - // system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] + // system 应为 array 格式,对齐真实 Claude Code CLI 的 2-block 形态: + // [0] billing attribution block (x-anthropic-billing-header: cc_version=...;) + // [1] Claude Code prompt block (带 cache_control) systemArr, ok := parsed["system"].([]any) require.True(t, ok, "system should be an array, got %T", parsed["system"]) - require.Len(t, systemArr, 1, "system array should have exactly 1 block") - systemBlock, ok := systemArr[0].(map[string]any) + require.Len(t, systemArr, 2, "system array should have exactly 2 blocks (billing + cc prompt)") + + billingBlock, ok := systemArr[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "text", billingBlock["type"]) + require.Contains(t, billingBlock["text"], "x-anthropic-billing-header:") + require.Contains(t, billingBlock["text"], "cc_version=") + require.Contains(t, billingBlock["text"], "cc_entrypoint=cli") + require.Contains(t, billingBlock["text"], "cch=00000") + + systemBlock, ok := systemArr[1].(map[string]any) require.True(t, ok) require.Equal(t, "text", systemBlock["type"]) require.Equal(t, tt.wantSystemText, systemBlock["text"]) cc, ok := systemBlock["cache_control"].(map[string]any) - require.True(t, ok, "system block should have cache_control") + require.True(t, ok, "cc prompt block should have cache_control") require.Equal(t, "ephemeral", cc["type"]) // 检查 messages diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 54127cae..f6185ff2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -122,7 +122,7 @@ func openAIStreamEventIsTerminal(data string) bool { return true } switch gjson.Get(trimmed, "type").String() { - case "response.completed", "response.done", "response.failed": + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": return true default: return false @@ -861,6 +861,7 @@ func (s *GatewayService) hashContent(content string) string { type anthropicCacheControlPayload struct { Type string `json:"type"` + TTL string `json:"ttl,omitempty"` } type anthropicSystemTextBlockPayload struct { @@ -912,7 +913,10 @@ func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]b Text: text, } if includeCacheControl { - block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"} + block.CacheControl = &anthropicCacheControlPayload{ + Type: "ephemeral", + TTL: claude.DefaultCacheControlTTL, + } } return json.Marshal(block) } @@ -1088,11 +1092,51 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu } } - // 注意:不再剥离 temperature 和 tool_choice。 - // 真实 CLI 在 thinking 关闭时发 temperature:1,透传 tool_choice。 - // 之前无条件剥离会导致: - // 1. temperature=0 的确定性请求被静默忽略 - // 2. tool_choice 强制工具调用被静默变成 auto 模式 + // temperature:真实 Claude Code CLI 总是发送 temperature(默认 1,客户端可覆盖)。 + // 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。 + // 策略:客户端传了什么就透传;没传则补默认 1。 + if !gjson.GetBytes(out, "temperature").Exists() { + if next, ok := setJSONValueBytes(out, "temperature", 1); ok { + out = next + modified = true + } + } + + // max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。 + if !gjson.GetBytes(out, "max_tokens").Exists() { + if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok { + out = next + modified = true + } + } + + // context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动 + // 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。 + // 客户端显式传了就透传;否则按 CLI 行为补齐。 + if !gjson.GetBytes(out, "context_management").Exists() { + thinkingType := gjson.GetBytes(out, "thinking.type").String() + if thinkingType == "enabled" || thinkingType == "adaptive" { + const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}` + if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok { + out = next + modified = true + } + } + } + + // tool_choice:与 Parrot 对齐,不再无条件删除。 + // - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由 + // applyToolNameRewriteToBody 同步映射为假名 + // - 其他形态(auto/any/none)原样透传 + // 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除 + if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 { + if gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } + } + } if !modified { return body, modelID @@ -1135,6 +1179,135 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) } +// applyClaudeCodeOAuthMimicryToBody 将"非 Claude Code 客户端 + Claude OAuth 账号" +// 路径上原本只在 /v1/messages 里做的完整伪装应用到任意 body 上。 +// +// 这是 /v1/messages 主路径上 rewriteSystemForNonClaudeCode + +// normalizeClaudeOAuthRequestBody 流程的通用版,供 OpenAI 协议兼容层 +// (ForwardAsChatCompletions / ForwardAsResponses) 复用。 +// +// 未抽离之前,OpenAI 协议兼容层仅做 injectClaudeCodePrompt(前置追加), +// 而仓内 /v1/messages 路径自己的注释明确说过"仅前置追加无法通过 Anthropic +// 第三方检测";那条注释就是本函数存在的根因。 +// +// 参数: +// - ctx / c:用于读取指纹和 gateway settings;c 可为 nil(如 count_tokens)。 +// - account:必须是 OAuth 账号,且调用方已判断不是 Claude Code 客户端。 +// - body:已经 marshal 成 Anthropic /v1/messages 格式的请求体。 +// - systemRaw:body 中原始 system 字段(用于判断是否需要 rewrite)。 +// - model:最终会发给上游的模型 ID(用于 haiku 旁路 + metadata 版本选择)。 +// +// 返回:改写后的 body。即使中间任何一步失败,也会退化成原 body(不会 panic)。 +func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + systemRaw any, + model string, +) []byte { + if account == nil || !account.IsOAuth() || len(body) == 0 { + return body + } + + systemRewritten := false + if !strings.Contains(strings.ToLower(model), "haiku") { + body = rewriteSystemForNonClaudeCode(body, systemRaw) + systemRewritten = true + } + + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} + + if s.identityService != nil && c != nil && c.Request != nil { + if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil { + mimicMPT := false + if s.settingService != nil { + _, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx) + } + if !mimicMPT { + if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" { + normalizeOpts.injectMetadata = true + normalizeOpts.metadataUserID = uid + } + } + } + } + + body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts) + + // Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点 + // 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束: + // 1) strip:先清除客户端的 messages[*].cache_control(多轮稳定性) + // 2) breakpoints:再注入 2 个断点(最后一条 + 倒数第二个 user turn) + // 3) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1] + // 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。 + body = stripMessageCacheControl(body) + body = addMessageCacheBreakpoints(body) + + if rw := buildToolNameRewriteFromBody(body); rw != nil { + body = applyToolNameRewriteToBody(body, rw) + if c != nil { + c.Set(toolNameRewriteKey, rw) + } + } else { + body = applyToolsLastCacheBreakpoint(body) + } + + return body +} + +// buildOAuthMetadataUserIDFromBody 是 buildOAuthMetadataUserID 的变体, +// 适用于调用方手上没有 ParsedRequest 的场景(如 OpenAI 协议兼容层)。 +// +// 与 buildOAuthMetadataUserID 的唯一区别: +// - session hash 从 body 本体按同样规则重算,而不是读取 ParsedRequest 缓存值。 +// - 如果 body 里已经存在 metadata.user_id,则返回空(由 ensureClaudeOAuthMetadataUserID +// 自行决定是否覆盖)。 +func (s *GatewayService) buildOAuthMetadataUserIDFromBody( + ctx context.Context, + account *Account, + fp *Fingerprint, + body []byte, +) string { + _ = ctx + if account == nil { + return "" + } + if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" { + return "" + } + + userID := strings.TrimSpace(account.GetClaudeUserID()) + if userID == "" && fp != nil { + userID = fp.ClientID + } + if userID == "" { + userID = generateClientID() + } + + sessionID := uuid.NewString() + if hash := hashBodyForSessionSeed(body); hash != "" { + sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash)) + } + + var uaVersion string + if fp != nil { + uaVersion = ExtractCLIVersion(fp.UserAgent) + } + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) +} + +// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。 +// 复用 SHA-256 + 截断,与 generateSessionUUID 的输入格式对齐。 +func hashBodyForSessionSeed(body []byte) string { + if len(body) == 0 { + return "" + } + sum := sha256.Sum256(body) + return fmt.Sprintf("%x", sum[:16]) +} + // GenerateSessionUUID creates a deterministic UUID4 from a seed string. func GenerateSessionUUID(seed string) string { return generateSessionUUID(seed) @@ -3579,16 +3752,6 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { return claudeCliUserAgentRe.MatchString(userAgent) } -func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { - if IsClaudeCodeClient(ctx) { - return true - } - if parsed == nil || c == nil { - return false - } - return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) -} - // normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil), // 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。 // 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。 @@ -3765,17 +3928,20 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { originalSystemText = strings.Join(parts, "\n\n") } - // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致) - // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。 - // 使用 string 格式会被 Anthropic 检测为第三方应用。 - claudeCodeSystemBlock := []map[string]any{ - { - "type": "text", - "text": claudeCodeSystemPrompt, - "cache_control": map[string]string{"type": "ephemeral"}, - }, + // 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态: + // [0] billing attribution block(cc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;) + // [1] "You are Claude Code..." prompt block(带 cache_control 作为稳定缓存断点) + // + // billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的 + // signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload + // 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。 + billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion) + ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if billingErr != nil || ccErr != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr) + return body } - out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock) + out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock})) if !ok { logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") return body @@ -4012,15 +4178,21 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) } - isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) - shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode - systemRewritten := false // hoisted: tracks whether rewriteSystemForNonClaudeCode was called + // OAuth 账号无条件走完整 mimicry,与 Parrot 对齐。 + // 不再检查 isClaudeCodeRequest —— 即使客户端自称 Claude Code(opencode 等 + // 第三方工具会伪装 UA / X-App / system prompt),它的伪装往往不完整(缺 billing + // block / 工具名混淆 / cache 策略等),被 Anthropic 判为 third-party。 + // 无条件覆盖不会对真正的 Claude Code 造成问题,因为我们的伪装更完整。 + shouldMimicClaudeCode := account.IsOAuth() if shouldMimicClaudeCode { - // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages - // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 - if !strings.Contains(strings.ToLower(reqModel), "haiku") && - !systemIncludesClaudeCodePrompt(parsed.System) { + // 与 Parrot 对齐:OAuth 账号无条件重写 system(即使客户端已发了 Claude Code + // 风格的 system prompt)。原因:第三方工具(opencode 等)会发 "You are Claude + // Code..." system prompt 但缺少 billing attribution block,导致 Anthropic + // 检测到"有 CC prompt 但无 billing block"的不一致而判为 third-party。 + // Parrot 的 transform_request 从不检查客户端 system 内容,直接覆盖。 + systemRewritten := false + if !strings.Contains(strings.ToLower(reqModel), "haiku") { body = rewriteSystemForNonClaudeCode(body, parsed.System) systemRewritten = true } @@ -4044,6 +4216,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + + // D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点 + // 与 forward_as_chat_completions / forward_as_responses 路径对齐, + // 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。 + body = stripMessageCacheControl(body) + body = addMessageCacheBreakpoints(body) + if rw := buildToolNameRewriteFromBody(body); rw != nil { + body = applyToolNameRewriteToBody(body, rw) + c.Set(toolNameRewriteKey, rw) + } else { + body = applyToolsLastCacheBreakpoint(body) + } } // 注入 x-anthropic-billing-header attribution block(所有 OAuth 账号) @@ -5036,7 +5220,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( } if !clientDisconnected { - if _, err := io.WriteString(w, line); err != nil { + restored := string(reverseToolNamesIfPresent(c, []byte(line))) + if _, err := io.WriteString(w, restored); err != nil { clientDisconnected = true logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) } else if _, err := io.WriteString(w, "\n"); err != nil { @@ -5206,6 +5391,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( if contentType == "" { contentType = "application/json" } + body = reverseToolNamesIfPresent(c, body) c.Data(resp.StatusCode, contentType, body) return usage, nil } @@ -5661,13 +5847,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex setHeaderRaw(req.Header, "x-api-key", token) } - // 白名单透传headers(恢复真实 wire casing) - for key, values := range clientHeaders { - lowerKey := strings.ToLower(key) - if allowedHeaders[lowerKey] { - wireKey := resolveWireCasing(key) - for _, v := range values { - addHeaderRaw(req.Header, wireKey, v) + // 白名单透传 headers + // OAuth mimicry 路径:跳过客户端 header 透传,与 Parrot 对齐。 + // Parrot 的 build_upstream_headers 只发 9 个精确 header,不透传任何客户端 header。 + // 透传客户端 header 会引入不一致的 x-stainless-* / anthropic-beta / user-agent / + // x-claude-code-session-id 等值,和我们注入的伪装 header 冲突,被 Anthropic 判 third-party。 + if tokenType != "oauth" || !mimicClaudeCode { + for key, values := range clientHeaders { + lowerKey := strings.ToLower(key) + if allowedHeaders[lowerKey] { + wireKey := resolveWireCasing(key) + for _, v := range values { + addHeaderRaw(req.Header, wireKey, v) + } } } } @@ -5708,7 +5900,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // Haiku models are exempt from third-party detection and don't need it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} if !strings.Contains(strings.ToLower(modelID), "haiku") { - requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking} + requiredBetas = claude.FullClaudeCodeMimicryBetas() } setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet)) } else { @@ -6222,6 +6414,11 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { if isStream { setHeaderRaw(req.Header, "x-stainless-helper-method", "stream") } + // Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。 + // 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。 + if getHeaderRaw(req.Header, "x-client-request-id") == "" { + setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString()) + } } func truncateForLog(b []byte, maxBytes int) string { @@ -6987,7 +7184,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http for _, block := range outputBlocks { if !clientDisconnected { - if _, werr := fmt.Fprint(w, block); werr != nil { + restored := reverseToolNamesIfPresent(c, []byte(block)) + if _, werr := fmt.Fprint(w, string(restored)); werr != nil { clientDisconnected = true logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") break @@ -7329,6 +7527,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + body = reverseToolNamesIfPresent(c, body) + // 写入响应 c.Data(resp.StatusCode, contentType, body) @@ -8320,12 +8520,19 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // Pre-filter: strip empty text blocks to prevent upstream 400. body = StripEmptyTextBlocks(body) - isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) - shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + shouldMimicClaudeCode := account.IsOAuth() if shouldMimicClaudeCode { normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + + body = stripMessageCacheControl(body) + body = addMessageCacheBreakpoints(body) + if rw := buildToolNameRewriteFromBody(body); rw != nil { + body = applyToolNameRewriteToBody(body, rw) + } else { + body = applyToolsLastCacheBreakpoint(body) + } } // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。 @@ -8749,7 +8956,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con applyClaudeCodeMimicHeaders(req, false) incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} + requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting) setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) } else { clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") diff --git a/backend/internal/service/gateway_tool_rewrite.go b/backend/internal/service/gateway_tool_rewrite.go new file mode 100644 index 00000000..c76cab62 --- /dev/null +++ b/backend/internal/service/gateway_tool_rewrite.go @@ -0,0 +1,313 @@ +package service + +import ( + "fmt" + "hash/fnv" + "math/rand" + "sort" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// toolNameRewriteKey 是 gin.Context 上存 ToolNameRewrite 映射的 key。 +// 请求阶段写入,响应阶段读取,用于 bytes 级逆向还原假名 → 真名。 +const toolNameRewriteKey = "claude_tool_name_rewrite" + +// staticToolNameRewrites 是"静态前缀映射",与 Parrot src/transform/cc_mimicry.py +// TOOL_NAME_REWRITES 完全一致。只有以这些前缀开头的工具会被重写。 +var staticToolNameRewrites = map[string]string{ + "sessions_": "cc_sess_", + "session_": "cc_ses_", +} + +// fakeToolNamePrefixes 是"动态映射"的前缀池,与 Parrot _FAKE_PREFIXES 一致。 +// 当 tools 数量 > dynamicToolMapThreshold 时随机选用其中前缀生成可读假名。 +var fakeToolNamePrefixes = []string{ + "analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_", + "process_", "query_", "render_", "resolve_", "sync_", "update_", + "validate_", "convert_", "extract_", "manage_", "monitor_", "parse_", + "review_", "search_", "transform_", "handle_", "invoke_", "notify_", +} + +// dynamicToolMapThreshold 与 Parrot 一致:tools 数量超过 5 才启用动态映射。 +// 少量工具不需要混淆(一般是 Claude Code 自己的核心工具 bash/edit/read 等)。 +const dynamicToolMapThreshold = 5 + +// ToolNameRewrite 是单次请求内的工具名混淆映射。 +// - Forward: real → fake,请求阶段在 body 上应用。 +// - Reverse: fake → real,响应阶段对每个 chunk 做 bytes.Replace 还原。 +// +// ReverseOrdered 是按假名长度倒序的 (fake, real) 列表,用于防止短假名是长假名的 +// 子串时 bytes.Replace 先被吃掉(对齐 Parrot _restore_tool_names_in_chunk 的 +// `sorted(..., key=lambda x: len(x[1]), reverse=True)`)。 +type ToolNameRewrite struct { + Forward map[string]string + Reverse map[string]string + ReverseOrdered [][2]string +} + +// buildDynamicToolMap 构造 tools 的动态假名映射。 +// +// 与 Parrot _build_dynamic_tool_map 语义等价: +// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil(不做动态映射,走静态 fallback) +// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中) +// +// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池; +// Go 无法字节级复刻 Python hash,但"稳定性"和"前缀池打散"两个不变量都保留: +// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。 +// 字节级不同不影响上游判定(Anthropic 不会验证我们的随机种子算法)。 +func buildDynamicToolMap(toolNames []string) map[string]string { + if len(toolNames) <= dynamicToolMapThreshold { + return nil + } + h := fnv.New64a() + for i, n := range toolNames { + if i > 0 { + _, _ = h.Write([]byte{0}) + } + _, _ = h.Write([]byte(n)) + } + rng := rand.New(rand.NewSource(int64(h.Sum64()))) + + available := make([]string, len(fakeToolNamePrefixes)) + copy(available, fakeToolNamePrefixes) + rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] }) + + mapping := make(map[string]string, len(toolNames)) + for i, name := range toolNames { + prefix := available[i%len(available)] + headLen := 3 + if len(name) < 3 { + headLen = len(name) + } + fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i) + mapping[name] = fake + } + return mapping +} + +// sanitizeToolName 把真名转成假名。 +// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。 +func sanitizeToolName(name string, dynamic map[string]string) string { + if dynamic != nil { + if fake, ok := dynamic[name]; ok { + return fake + } + } + for prefix, replacement := range staticToolNameRewrites { + if strings.HasPrefix(name, prefix) { + return replacement + name[len(prefix):] + } + } + return name +} + +// shouldMimicToolName 指示某个 tool 是否需要重命名。 +// server tool(type != "" 且不是 "function" / "custom")是 Anthropic 协议语义的一部分, +// 比如 "web_search_20250305" / "computer_20250124";误改会导致上游拒绝。 +func shouldMimicToolName(toolType string) bool { + if toolType == "" || toolType == "function" || toolType == "custom" { + return true + } + return false +} + +// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name,构造 ToolNameRewrite +// 并返回它。若不需要混淆(tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。 +// +// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。 +func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return nil + } + + mimicableNames := make([]string, 0) + toolsArr := tools.Array() + for _, t := range toolsArr { + if !shouldMimicToolName(t.Get("type").String()) { + continue + } + name := t.Get("name").String() + if name == "" { + continue + } + mimicableNames = append(mimicableNames, name) + } + + dynamic := buildDynamicToolMap(mimicableNames) + + rw := &ToolNameRewrite{ + Forward: make(map[string]string), + Reverse: make(map[string]string), + } + for _, name := range mimicableNames { + fake := sanitizeToolName(name, dynamic) + if fake == name { + continue + } + rw.Forward[name] = fake + rw.Reverse[fake] = name + } + if len(rw.Forward) == 0 { + return nil + } + + rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse)) + for fake, real := range rw.Reverse { + rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real}) + } + sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool { + return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0]) + }) + + return rw +} + +// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上: +// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool) +// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐, +// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL) +// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool") +// +// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致; +// 响应侧 bytes.Replace 会连带还原它们。 +func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte { + if rw == nil || len(rw.Forward) == 0 { + body = applyToolsLastCacheBreakpoint(body) + return body + } + + tools := gjson.GetBytes(body, "tools") + if tools.IsArray() { + idx := -1 + tools.ForEach(func(_, t gjson.Result) bool { + idx++ + if !shouldMimicToolName(t.Get("type").String()) { + return true + } + name := t.Get("name").String() + if name == "" { + return true + } + fake, ok := rw.Forward[name] + if !ok { + return true + } + if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil { + body = next + } + return true + }) + } + + if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" { + name := tc.Get("name").String() + if fake, ok := rw.Forward[name]; ok { + if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil { + body = next + } + } + } + + body = applyToolsLastCacheBreakpoint(body) + return body +} + +// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control +// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}` +// 行为,但 ttl 按本仓规则: +// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖 +// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL} +// +// 纯副作用函数,tools 不存在或为空数组时 no-op。 +func applyToolsLastCacheBreakpoint(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return body + } + arr := tools.Array() + if len(arr) == 0 { + return body + } + lastIdx := len(arr) - 1 + existingCC := arr[lastIdx].Get("cache_control") + + if existingCC.Exists() && existingCC.Get("ttl").String() != "" { + return body + } + + if existingCC.Exists() { + if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil { + body = next + } + return body + } + + raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL) + if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil { + body = next + } + return body +} + +// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。 +// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace,防止子串冲突 +// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。 +// 再做静态前缀还原(cc_sess_ → sessions_ / cc_ses_ → session_)。 +// +// rw 可为 nil;nil 时仍会做静态前缀还原。 +func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte { + if rw != nil { + for _, pair := range rw.ReverseOrdered { + fake, real := pair[0], pair[1] + if fake == "" || fake == real { + continue + } + data = replaceAllBytes(data, fake, real) + } + } + for prefix, replacement := range staticToolNameRewrites { + data = replaceAllBytes(data, replacement, prefix) + } + return data +} + +// replaceAllBytes 是 bytes.ReplaceAll 的便捷封装,避免每个调用点各自做 []byte 转换。 +func replaceAllBytes(data []byte, from, to string) []byte { + if len(data) == 0 || from == to || !strings.Contains(string(data), from) { + return data + } + return []byte(strings.ReplaceAll(string(data), from, to)) +} + +// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。 +// 找不到(c==nil 或 key 不存在或类型不对)时返回 nil;调用方必须能处理 nil。 +func toolNameRewriteFromContext(c interface { + Get(string) (any, bool) +}) *ToolNameRewrite { + if c == nil { + return nil + } + raw, ok := c.Get(toolNameRewriteKey) + if !ok || raw == nil { + return nil + } + rw, _ := raw.(*ToolNameRewrite) + return rw +} + +// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping +// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。 +func reverseToolNamesIfPresent(c interface { + Get(string) (any, bool) +}, chunk []byte) []byte { + rw := toolNameRewriteFromContext(c) + if rw == nil && len(staticToolNameRewrites) == 0 { + return chunk + } + return restoreToolNamesInBytes(chunk, rw) +} diff --git a/backend/internal/service/gateway_tool_rewrite_test.go b/backend/internal/service/gateway_tool_rewrite_test.go new file mode 100644 index 00000000..8f0e3939 --- /dev/null +++ b/backend/internal/service/gateway_tool_rewrite_test.go @@ -0,0 +1,185 @@ +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) { + // Parrot 行为:tools 数量 ≤ 5 时不做动态映射。 + names := []string{"bash", "edit", "read", "write", "search"} + require.Nil(t, buildDynamicToolMap(names)) +} + +func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) { + // Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。 + names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"} + a := buildDynamicToolMap(names) + b := buildDynamicToolMap(names) + require.NotNil(t, a) + require.Equal(t, a, b, "same input tool_names must yield identical mapping") + require.Len(t, a, 6) + for _, name := range names { + require.Contains(t, a, name) + require.NotEqual(t, name, a[name]) + } +} + +func TestSanitizeToolName_StaticPrefix(t *testing.T) { + require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil)) + require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil)) + require.Equal(t, "bash", sanitizeToolName("bash", nil)) +} + +func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) { + dyn := map[string]string{"sessions_list": "analyze_ses00"} + got := sanitizeToolName("sessions_list", dyn) + require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix") +} + +func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) { + // 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时, + // 长的必须先替换。本测试用显式构造的映射来验证排序不变量。 + rw := &ToolNameRewrite{ + Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"}, + Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"}, + } + // 手工构造 ReverseOrdered:长的在前 + rw.ReverseOrdered = [][2]string{ + {"abc_12_ext", "bar"}, + {"abc_12", "foo"}, + } + data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`) + restored := string(restoreToolNamesInBytes(data, rw)) + require.Equal(t, `{"tool":"bar","other":"foo"}`, restored) +} + +func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) { + data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`) + got := string(restoreToolNamesInBytes(data, nil)) + require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got) +} + +func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) { + body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`) + rw := buildToolNameRewriteFromBody(body) + require.NotNil(t, rw) + require.Contains(t, rw.Forward, "sessions_list") + require.Contains(t, rw.Forward, "session_get") + // web_search is a server tool, not rewritten + require.NotContains(t, rw.Forward, "web_search") + + out := applyToolNameRewriteToBody(body, rw) + + // tools[0].name and tools[1].name rewritten; tools[2].name untouched + require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String()) + require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String()) + require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String()) + + // tool_choice.name rewritten + require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String()) + require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String()) +} + +func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) { + body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`) + out := applyToolsLastCacheBreakpoint(body) + require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String()) + require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String()) + // First tool untouched + require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists()) +} + +func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) { + body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`) + out := applyToolsLastCacheBreakpoint(body) + // User-provided ttl must be preserved. + require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String()) +} + +func TestStripMessageCacheControl(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`) + out := stripMessageCacheControl(body) + require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists()) +} + +func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + out := addMessageCacheBreakpoints(body) + require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String()) + require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) { + // Parrot 不变量:messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。 + body := []byte(`{"messages":[ + {"role":"user","content":[{"type":"text","text":"q1"}]}, + {"role":"assistant","content":[{"type":"text","text":"a1"}]}, + {"role":"user","content":[{"type":"text","text":"q2"}]}, + {"role":"assistant","content":[{"type":"text","text":"a2"}]} + ]}`) + out := addMessageCacheBreakpoints(body) + // 最后一条 assistant 被打断点 + require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String()) + // 倒数第二个 user turn = index 0(唯一另一个 user) + require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String()) + // 其他不打断点 + require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists()) + require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists()) +} + +func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`) + out := addMessageCacheBreakpoints(body) + // content 升级成数组 + require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray()) + require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String()) + require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String()) + require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) { + // 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列 + body := []byte(`{"tools":[ + {"name":"t1","input_schema":{}}, + {"name":"t2","input_schema":{}}, + {"name":"t3","input_schema":{}}, + {"name":"t4","input_schema":{}}, + {"name":"t5","input_schema":{}}, + {"name":"t6","input_schema":{}} + ]}`) + rw := buildToolNameRewriteFromBody(body) + require.NotNil(t, rw) + require.NotEmpty(t, rw.ReverseOrdered) + for i := 1; i < len(rw.ReverseOrdered); i++ { + require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]), + "ReverseOrdered must be sorted by fake-name length descending") + } +} + +func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) { + data := []byte("plain text without any tool names") + require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil))) +} + +// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}". +func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) { + names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"} + m := buildDynamicToolMap(names) + require.NotNil(t, m) + for _, name := range names { + fake, ok := m[name] + require.True(t, ok) + // fake = prefix + head3 + "%02d" + // ends with two decimal digits + require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake) + head := name + if len(head) > 3 { + head = head[:3] + } + require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name) + } +} diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 91d452db..52b89fc8 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -25,21 +25,16 @@ var ( userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) -func defaultIdentityFingerprint() Fingerprint { - profile := claude.DefaultDeviceProfile() - return Fingerprint{ - UserAgent: profile.UserAgent, - StainlessLang: profile.StainlessLang, - StainlessPackageVersion: profile.StainlessPackageVersion, - StainlessOS: profile.StainlessOS, - StainlessArch: profile.StainlessArch, - StainlessRuntime: profile.StainlessRuntime, - StainlessRuntimeVersion: profile.StainlessRuntimeVersion, - } -} - // 默认指纹值(当客户端未提供时使用) -var defaultFingerprint = defaultIdentityFingerprint() +var defaultFingerprint = Fingerprint{ + UserAgent: "claude-cli/2.1.92 (external, cli)", + StainlessLang: "js", + StainlessPackageVersion: "0.70.0", + StainlessOS: "Linux", + StainlessArch: "arm64", + StainlessRuntime: "node", + StainlessRuntimeVersion: "v24.13.0", +} // Fingerprint represents account fingerprint data type Fingerprint struct { diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 808f1229..7a0a6636 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -3,7 +3,6 @@ package service import ( "container/heap" "context" - "errors" "fmt" "hash/fnv" "math" @@ -45,6 +44,7 @@ type OpenAIAccountScheduleRequest struct { RequestedModel string RequiredTransport OpenAIUpstreamTransport RequiredImageCapability OpenAIImagesCapability + RequireCompact bool ExcludedIDs map[int64]struct{} } @@ -258,12 +258,16 @@ func (s *defaultOpenAIAccountScheduler) Select( previousResponseID, req.RequestedModel, req.ExcludedIDs, + req.RequireCompact, ) if err != nil { return nil, decision, err } if selection != nil && selection.Account != nil { if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) { + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } selection = nil } } @@ -348,8 +352,8 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel) - if account == nil { + account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact) + if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) { _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } @@ -590,7 +594,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( return nil, 0, 0, 0, err } if len(accounts) == 0 { - return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false) } // require_privacy_set: 获取分组信息 @@ -630,7 +634,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( }) } if len(filtered) == 0 { - return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false) } loadMap := map[int64]*AccountLoadInfo{} @@ -640,45 +644,14 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } } - minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority - maxWaiting := 1 - loadRateSum := 0.0 - loadRateSumSquares := 0.0 - minTTFT, maxTTFT := 0.0, 0.0 - hasTTFTSample := false - candidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered)) for _, account := range filtered { loadInfo := loadMap[account.ID] if loadInfo == nil { loadInfo = &AccountLoadInfo{AccountID: account.ID} } - if account.Priority < minPriority { - minPriority = account.Priority - } - if account.Priority > maxPriority { - maxPriority = account.Priority - } - if loadInfo.WaitingCount > maxWaiting { - maxWaiting = loadInfo.WaitingCount - } errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) - if hasTTFT && ttft > 0 { - if !hasTTFTSample { - minTTFT, maxTTFT = ttft, ttft - hasTTFTSample = true - } else { - if ttft < minTTFT { - minTTFT = ttft - } - if ttft > maxTTFT { - maxTTFT = ttft - } - } - } - loadRate := float64(loadInfo.LoadRate) - loadRateSum += loadRate - loadRateSumSquares += loadRate * loadRate - candidates = append(candidates, openAIAccountCandidateScore{ + allCandidates = append(allCandidates, openAIAccountCandidateScore{ account: account, loadInfo: loadInfo, errorRate: errorRate, @@ -686,53 +659,183 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( hasTTFT: hasTTFT, }) } - loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) - weights := s.service.openAIWSSchedulerWeights() - for i := range candidates { - item := &candidates[i] - priorityFactor := 1.0 - if maxPriority > minPriority { - priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + // Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用 + // 时作为最后兜底(snapshot 可能已陈旧)。 + candidates := allCandidates + staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates)) + if req.RequireCompact { + candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates)) + for _, candidate := range allCandidates { + if openAICompactSupportTier(candidate.account) == 0 { + staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate) + continue + } + candidates = append(candidates, candidate) } - loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) - queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) - errorFactor := 1 - clamp01(item.errorRate) - ttftFactor := 0.5 - if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { - ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 { + return nil, 0, 0, 0, ErrNoAvailableCompactAccounts } - - item.score = weights.Priority*priorityFactor + - weights.Load*loadFactor + - weights.Queue*queueFactor + - weights.ErrorRate*errorFactor + - weights.TTFT*ttftFactor } - topK := s.service.openAIWSLBTopK() - if topK > len(candidates) { - topK = len(candidates) - } - if topK <= 0 { - topK = 1 - } - rankedCandidates := selectTopKOpenAICandidates(candidates, topK) - selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req) + candidateCount := len(candidates) + loadSkew := 0.0 + if len(candidates) > 0 { + minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + for _, candidate := range candidates { + if candidate.account.Priority < minPriority { + minPriority = candidate.account.Priority + } + if candidate.account.Priority > maxPriority { + maxPriority = candidate.account.Priority + } + if candidate.loadInfo.WaitingCount > maxWaiting { + maxWaiting = candidate.loadInfo.WaitingCount + } + if candidate.hasTTFT && candidate.ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = candidate.ttft, candidate.ttft + hasTTFTSample = true + } else { + if candidate.ttft < minTTFT { + minTTFT = candidate.ttft + } + if candidate.ttft > maxTTFT { + maxTTFT = candidate.ttft + } + } + } + loadRate := float64(candidate.loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + } + loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + } + } + + topK := 0 + if len(candidates) > 0 { + topK = s.service.openAIWSLBTopK() + if topK > len(candidates) { + topK = len(candidates) + } + if topK <= 0 { + topK = 1 + } + } + + buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { + if len(pool) == 0 || topK <= 0 { + return nil + } + groupTopK := topK + if groupTopK > len(pool) { + groupTopK = len(pool) + } + ranked := selectTopKOpenAICandidates(pool, groupTopK) + return buildOpenAIWeightedSelectionOrder(ranked, req) + } + sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { + if len(pool) == 0 { + return nil + } + ordered := append([]openAIAccountCandidateScore(nil), pool...) + sort.SliceStable(ordered, func(i, j int) bool { + a, b := ordered[i], ordered[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount { + return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + return ordered + } + + selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates)) + if req.RequireCompact { + supported := make([]openAIAccountCandidateScore, 0, len(candidates)) + unknown := make([]openAIAccountCandidateScore, 0, len(candidates)) + for _, candidate := range candidates { + switch openAICompactSupportTier(candidate.account) { + case 2: + supported = append(supported, candidate) + case 1: + unknown = append(unknown, candidate) + } + } + if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil { + return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts + } + selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...) + selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...) + if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil { + selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...) + } + } else { + selectionOrder = buildSelectionOrder(candidates) + } + if len(selectionOrder) == 0 { + return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0) + } + + compactBlocked := false for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } + if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { + compactBlocked = true + continue + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { - return nil, len(candidates), topK, loadSkew, acquireErr + return nil, candidateCount, topK, loadSkew, acquireErr } if result != nil && result.Acquired { if req.SessionHash != "" { @@ -742,17 +845,25 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, - }, len(candidates), topK, loadSkew, nil + }, candidateCount, topK, loadSkew, nil } } cfg := s.service.schedulingConfig() // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + continue + } + if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { + compactBlocked = true + continue + } return &AccountSelectionResult{ Account: fresh, WaitPlan: &AccountWaitPlan{ @@ -761,10 +872,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, - }, len(candidates), topK, loadSkew, nil + }, candidateCount, topK, loadSkew, nil } - return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts + return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, compactBlocked) } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { @@ -905,8 +1016,9 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requestedModel string, excludedIDs map[int64]struct{}, requiredTransport OpenAIUpstreamTransport, + requireCompact bool, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "") + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact) } func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( @@ -917,13 +1029,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( excludedIDs map[int64]struct{}, requiredCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) + selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false) if err == nil && selection != nil && selection.Account != nil { return selection, decision, nil } // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号) if requiredCapability == OpenAIImagesCapabilityNative { - return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic) + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false) } return selection, decision, err } @@ -937,6 +1049,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( excludedIDs map[int64]struct{}, requiredTransport OpenAIUpstreamTransport, requiredImageCapability OpenAIImagesCapability, + requireCompact bool, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} scheduler := s.getOpenAIAccountScheduler(ctx) @@ -945,7 +1058,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) for { - selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact) if err != nil { return nil, decision, err } @@ -970,7 +1083,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) for { - selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact) if err != nil { return nil, decision, err } @@ -1008,6 +1121,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( RequestedModel: requestedModel, RequiredTransport: requiredTransport, RequiredImageCapability: requiredImageCapability, + RequireCompact: requireCompact, ExcludedIDs: excludedIDs, }) } diff --git a/backend/internal/service/openai_account_scheduler_compact_test.go b/backend/internal/service/openai_account_scheduler_compact_test.go new file mode 100644 index 00000000..f7e08a20 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_compact_test.go @@ -0,0 +1,195 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown +// 验证 compact 调度时显式支持 (tier=2) 优先于未探测 (tier=1)。 +func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(91001) + accounts := []Account{ + { + ID: 71001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{}, // unknown + }, + { + ID: 71002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{"openai_compact_supported": true}, // tier=2 + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.4", + nil, + OpenAIUpstreamTransportAny, + true, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(71002), selection.Account.ID, "compact-supported account should win over unknown") +} + +// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported +// 验证 force_off / 已探测不支持 (tier=0) 的账号不会被 compact 请求选中。 +func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(91002) + accounts := []Account{ + { + ID: 71010, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff}, + }, + { + ID: 71011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{"openai_compact_supported": false}, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.4", + nil, + OpenAIUpstreamTransportAny, + true, + ) + require.Error(t, err) + require.True(t, errors.Is(err, ErrNoAvailableCompactAccounts), "compact-only accounts should rejected explicitly unsupported and return compact error") + require.Nil(t, selection) +} + +// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown +// 验证当没有"已知支持"账号时,compact 请求会回退到"未探测"账号。 +func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(91003) + accounts := []Account{ + { + ID: 71020, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{"openai_compact_supported": false}, // tier=0 + }, + { + ID: 71021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{}, // unknown -> tier=1 + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.4", + nil, + OpenAIUpstreamTransportAny, + true, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(71021), selection.Account.ID, "unknown account should be picked when no supported account available") +} + +// TestOpenAICompactSupportTier 验证 tier 分类逻辑。 +func TestOpenAICompactSupportTier(t *testing.T) { + tests := []struct { + name string + account *Account + want int + }{ + {name: "nil", account: nil, want: 0}, + {name: "non openai", account: &Account{Platform: PlatformAnthropic}, want: 0}, + {name: "openai unknown", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{}}, want: 1}, + {name: "openai supported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": true}}, want: 2}, + {name: "openai unsupported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": false}}, want: 0}, + {name: "force on", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}}, want: 2}, + {name: "force off overrides probe true", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff, "openai_compact_supported": true}}, want: 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := openAICompactSupportTier(tt.account); got != tt.want { + t.Fatalf("openAICompactSupportTier(...) = %d, want %d", got, tt.want) + } + }) + } +} diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index b02370cb..0950ee54 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -289,6 +289,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -343,6 +344,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require "gpt-5.1", nil, OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -384,6 +386,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require "gpt-5.1", nil, OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) require.ErrorContains(t, err, "no available OpenAI accounts") require.Nil(t, selection) @@ -445,6 +448,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -486,7 +490,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } - selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false) require.NoError(t, err) require.NotNil(t, selection) require.NotNil(t, selection.Account) @@ -540,7 +544,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } - selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false) require.NoError(t, err) require.NotNil(t, selection) require.NotNil(t, selection.Account) @@ -616,6 +620,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -662,6 +667,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -740,6 +746,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -788,6 +795,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -857,6 +865,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick "gpt-5.1", nil, OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -900,6 +909,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl "gpt-5.1", nil, OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) require.Error(t, err) require.Nil(t, selection) @@ -976,6 +986,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) @@ -1014,7 +1025,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } - selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false) require.NoError(t, err) require.NotNil(t, selection) svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120)) @@ -1218,6 +1229,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA "gpt-5.1", nil, OpenAIUpstreamTransportAny, + false, ) require.NoError(t, err) require.NotNil(t, selection) diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go index ddafc6eb..8d63e68e 100644 --- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -54,6 +54,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh "gpt-5.1", nil, OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) require.NoError(t, err) require.NotNil(t, selection) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 14abde9b..e765d7e9 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -1,6 +1,7 @@ package service import ( + "encoding/json" "fmt" "strings" ) @@ -48,6 +49,8 @@ type codexTransformResult struct { const ( codexImageGenerationBridgeMarker = "" codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n" + codexSparkImageUnsupportedMarker = "" + codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n" ) func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { @@ -151,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if normalizeCodexTools(reqBody) { result.Modified = true } + if normalizeCodexToolChoice(reqBody) { + result.Modified = true + } if v, ok := reqBody["prompt_cache_key"].(string); ok { result.PromptCacheKey = strings.TrimSpace(v) @@ -165,9 +171,20 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if applyInstructions(reqBody, isCodexCLI) { result.Modified = true } + if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) { + result.Modified = true + } // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 if input, ok := reqBody["input"].([]any); ok { + if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified { + input = normalizedInput + result.Modified = true + } + if normalizedInput, modified := normalizeCodexMessageContentText(input); modified { + input = normalizedInput + result.Modified = true + } input = filterCodexInput(input, needsToolContinuation) reqBody["input"] = input result.Modified = true @@ -192,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact return result } +func normalizeCodexToolChoice(reqBody map[string]any) bool { + choice, ok := reqBody["tool_choice"] + if !ok || choice == nil { + return false + } + choiceMap, ok := choice.(map[string]any) + if !ok { + return false + } + choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"])) + if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) { + return false + } + reqBody["tool_choice"] = "auto" + return true +} + +func codexToolsContainType(rawTools any, toolType string) bool { + tools, ok := rawTools.([]any) + if !ok || strings.TrimSpace(toolType) == "" { + return false + } + for _, rawTool := range tools { + tool, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType { + return true + } + } + return false +} + +func normalizeCodexToolRoleMessages(input []any) ([]any, bool) { + if len(input) == 0 { + return input, false + } + + modified := false + normalized := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + normalized = append(normalized, item) + continue + } + role, _ := m["role"].(string) + if strings.TrimSpace(role) != "tool" { + normalized = append(normalized, item) + continue + } + + callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"]) + callID = strings.TrimSpace(callID) + if callID == "" { + // Responses does not accept role:"tool". If no call id is available, + // preserve the text as a user message instead of sending invalid input. + fallback := make(map[string]any, len(m)) + for key, value := range m { + fallback[key] = value + } + fallback["role"] = "user" + delete(fallback, "tool_call_id") + normalized = append(normalized, fallback) + modified = true + continue + } + + output := extractTextFromContent(m["content"]) + if output == "" { + if value, ok := m["output"].(string); ok { + output = value + } + } + if output == "" && m["content"] != nil { + if b, err := json.Marshal(m["content"]); err == nil { + output = string(b) + } + } + + normalized = append(normalized, map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": output, + }) + modified = true + } + if !modified { + return input, false + } + return normalized, true +} + +func normalizeCodexMessageContentText(input []any) ([]any, bool) { + if len(input) == 0 { + return input, false + } + + modified := false + normalized := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" { + normalized = append(normalized, item) + continue + } + parts, ok := m["content"].([]any) + if !ok { + normalized = append(normalized, item) + continue + } + + var newItem map[string]any + var newParts []any + ensureItemCopy := func() { + if newItem != nil { + return + } + newItem = make(map[string]any, len(m)) + for key, value := range m { + newItem[key] = value + } + newParts = make([]any, len(parts)) + copy(newParts, parts) + } + + for i, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + text, hasText := part["text"] + if !hasText { + continue + } + if _, ok := text.(string); ok { + continue + } + + ensureItemCopy() + newPart := make(map[string]any, len(part)) + for key, value := range part { + newPart[key] = value + } + newPart["text"] = stringifyCodexContentText(text) + newParts[i] = newPart + modified = true + } + + if newItem != nil { + newItem["content"] = newParts + normalized = append(normalized, newItem) + continue + } + normalized = append(normalized, item) + } + if !modified { + return input, false + } + return normalized, true +} + +func stringifyCodexContentText(value any) string { + switch v := value.(type) { + case string: + return v + case nil: + return "" + default: + if b, err := json.Marshal(v); err == nil { + return string(b) + } + return fmt.Sprint(v) + } +} + func normalizeCodexModel(model string) string { model = strings.TrimSpace(model) if model == "" { @@ -244,6 +438,10 @@ func normalizeCodexModel(model string) string { return "gpt-5.4" } +func isCodexSparkModel(model string) bool { + return normalizeCodexModel(model) == "gpt-5.3-codex-spark" +} + func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { @@ -265,6 +463,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { return false } +func hasOpenAIInputImage(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"]) +} + +func hasOpenAIInputImageValue(value any) bool { + switch v := value.(type) { + case []any: + for _, item := range v { + if hasOpenAIInputImageValue(item) { + return true + } + } + case map[string]any: + if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" { + return true + } + if _, ok := v["image_url"]; ok { + return true + } + return hasOpenAIInputImageValue(v["content"]) + } + return false +} + +func validateCodexSparkInput(reqBody map[string]any, model string) error { + if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) { + return nil + } + return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model)) +} + func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { @@ -309,6 +541,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool { if len(reqBody) == 0 { return false } + if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) { + return false + } tool := map[string]any{ "type": "image_generation", @@ -344,6 +579,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool { if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) { return false } + if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) { + return false + } existing, _ := reqBody["instructions"].(string) if strings.Contains(existing, codexImageGenerationBridgeMarker) { @@ -360,6 +598,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool { return true } +func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + existing, _ := reqBody["instructions"].(string) + if strings.Contains(existing, codexSparkImageUnsupportedMarker) { + return false + } + existing = strings.TrimRight(existing, " \t\r\n") + if strings.TrimSpace(existing) == "" { + reqBody["instructions"] = codexSparkImageUnsupportedText + return true + } + reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText + return true +} + func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error { if !hasOpenAIImageGenerationTool(reqBody) { return nil @@ -658,12 +913,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } } + if !isCodexToolCallItemType(typ) { + ensureCopy() + delete(newItem, "call_id") + } + + if codexInputItemRequiresName(typ) { + if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" { + name := firstNonEmptyString(m["tool_name"]) + if name == "" { + if function, ok := m["function"].(map[string]any); ok { + name = firstNonEmptyString(function["name"]) + } + } + if name == "" { + name = "tool" + } + ensureCopy() + newItem["name"] = name + } + } + if !preserveReferences { ensureCopy() delete(newItem, "id") - if !isCodexToolCallItemType(typ) { - delete(newItem, "call_id") - } } filtered = append(filtered, newItem) @@ -672,10 +945,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } func isCodexToolCallItemType(typ string) bool { - if typ == "" { + switch typ { + case "function_call", + "tool_call", + "local_shell_call", + "tool_search_call", + "custom_tool_call", + "mcp_tool_call", + "function_call_output", + "mcp_tool_call_output", + "custom_tool_call_output", + "tool_search_output": + return true + default: + return false + } +} + +func codexInputItemRequiresName(typ string) bool { + switch strings.TrimSpace(typ) { + case "function_call", "custom_tool_call", "mcp_tool_call": + return true + default: return false } - return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output") } func normalizeCodexTools(reqBody map[string]any) bool { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 4fd16fdb..75f5c55c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -92,6 +92,235 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly require.Equal(t, "fc1", second["call_id"]) } +func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"}, + }, + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "tool_search_output", first["type"]) + require.Equal(t, "fc1", first["call_id"]) +} + +func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"}, + map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"}, + }, + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "fccustom", first["call_id"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "fcmcp", second["call_id"]) +} + +func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"}, + map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "ig_123", first["id"]) + _, hasCallID := first["call_id"] + require.False(t, hasCallID) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + _, hasCallID = second["call_id"] + require.False(t, hasCallID) +} + +func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{ + "type": "message", + "role": "tool", + "tool_call_id": "call_1", + "content": "ok", + }, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + + item, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "function_call_output", item["type"]) + require.Equal(t, "fc1", item["call_id"]) + require.Equal(t, "ok", item["output"]) + _, hasRole := item["role"] + require.False(t, hasRole) +} + +func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{ + "type": "message", + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": []any{"a", "b"}}, + }, + }, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + item, ok := input[0].(map[string]any) + require.True(t, ok) + content, ok := item["content"].([]any) + require.True(t, ok) + part, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, `["a","b"]`, part["text"]) +} + +func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "function", "name": "shell"}, + }, + "tool_choice": map[string]any{"type": "custom"}, + } + + applyCodexOAuthTransform(reqBody, true, false) + + require.Equal(t, "auto", reqBody["tool_choice"]) +} + +func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "custom", "name": "shell"}, + }, + "tool_choice": map[string]any{"type": "custom"}, + } + + applyCodexOAuthTransform(reqBody, true, false) + + choice, ok := reqBody["tool_choice"].(map[string]any) + require.True(t, ok) + require.Equal(t, "custom", choice["type"]) +} + +func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{"type": "message", "role": "user", "content": "run tool"}, + map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + item, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "function_call", item["type"]) + require.Equal(t, "tool", item["name"]) + require.Equal(t, "fc1", item["call_id"]) +} + +func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + item, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "shell", item["name"]) + require.Equal(t, "fc1", item["call_id"]) +} + +func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{ + "type": "mcp_tool_call", + "call_id": "call_abc", + "name": "remote_tool", + "arguments": "{}", + }, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + item, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "mcp_tool_call", item["type"]) + require.Equal(t, "remote_tool", item["name"]) + require.Equal(t, "fcabc", item["call_id"]) +} + +func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) { + for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} { + require.True(t, codexInputItemRequiresName(typ), typ) + require.True(t, isCodexToolCallItemType(typ), typ) + } +} + func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 @@ -261,6 +490,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) { require.Equal(t, "png", tool["output_format"]) } +func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": "draw a cat", + } + + modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.False(t, modified) + require.NotContains(t, reqBody, "tools") +} + func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.4", @@ -306,6 +546,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t * func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) { reqBody := map[string]any{ + "model": "gpt-5.4", "instructions": "existing instructions", "tools": []any{ map[string]any{"type": "image_generation", "output_format": "png"}, @@ -325,6 +566,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin require.False(t, modified) } +func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "instructions": "existing instructions", + "tools": []any{ + map[string]any{"type": "image_generation", "output_format": "png"}, + }, + } + + modified := applyCodexImageGenerationBridgeInstructions(reqBody) + require.False(t, modified) + require.Equal(t, "existing instructions", reqBody["instructions"]) +} + func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) { reqBody := map[string]any{ "instructions": "existing instructions", @@ -338,6 +593,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te require.Equal(t, "existing instructions", reqBody["instructions"]) } +func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "describe"}, + map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="}, + }, + }, + }, + } + + err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark") + require.Error(t, err) + require.Contains(t, err.Error(), "does not support image input") +} + +func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe"}, + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}}, + }, + }, + }, + } + + err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark") + require.Error(t, err) +} + +func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "hello"}, + }, + }, + }, + } + + require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")) +} + +func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "instructions": "existing instructions", + "input": "hello", + } + + result := applyCodexOAuthTransform(reqBody, true, false) + require.True(t, result.Modified) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Contains(t, instructions, "existing instructions") + require.Contains(t, instructions, codexSparkImageUnsupportedMarker) + require.Contains(t, instructions, "does not support image generation") + require.Contains(t, instructions, "switch to a non-Spark Codex model") + require.NotContains(t, instructions, codexImageGenerationBridgeMarker) +} + +func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "instructions": "existing instructions", + "input": "hello", + } + + applyCodexOAuthTransform(reqBody, true, false) + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotContains(t, instructions, codexSparkImageUnsupportedMarker) +} + func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) { reqBody := map[string]any{ "model": "gpt-image-2", diff --git a/backend/internal/service/openai_compact_model_mapping_test.go b/backend/internal/service/openai_compact_model_mapping_test.go new file mode 100644 index 00000000..fc408e64 --- /dev/null +++ b/backend/internal/service/openai_compact_model_mapping_test.go @@ -0,0 +1,135 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"compact-test","input":"hello"}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-map"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"gpt-5.4-openai-compact","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"}, + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-5.4", result.Model) + require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel) + require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String()) +} + +func TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"normal-test","input":"hello"}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-normal-map"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_124","status":"completed","model":"gpt-5.4","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 2, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"}, + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-5.4", result.Model) + require.Equal(t, "gpt-5.4", result.UpstreamModel) + require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) +} + +func TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + + originalBody := []byte(`{"model":"gpt-5.4","stream":true,"store":true,"instructions":"compact-pass","input":[{"type":"text","text":"compact me"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-pass-map"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"cmp_124","model":"gpt-5.4-openai-compact","usage":{"input_tokens":2,"output_tokens":3}}`)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 3, + Name: "openai-oauth-pass", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"}, + }, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-5.4", result.Model) + require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel) + require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "gpt-5.4", gjson.GetBytes(rec.Body.Bytes(), "model").String()) +} diff --git a/backend/internal/service/openai_compact_probe.go b/backend/internal/service/openai_compact_probe.go new file mode 100644 index 00000000..e8deff2d --- /dev/null +++ b/backend/internal/service/openai_compact_probe.go @@ -0,0 +1,120 @@ +package service + +import ( + "net/http" + "strconv" + "strings" + "time" +) + +const ( + // AccountTestModeDefault drives the standard /responses connection test. + AccountTestModeDefault = "default" + // AccountTestModeCompact drives the /responses/compact compact-probe test. + AccountTestModeCompact = "compact" +) + +func normalizeAccountTestMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case AccountTestModeCompact: + return AccountTestModeCompact + default: + return AccountTestModeDefault + } +} + +func createOpenAICompactProbePayload(model string) map[string]any { + return map[string]any{ + "model": strings.TrimSpace(model), + "instructions": "You are a helpful coding assistant.", + "input": []any{ + map[string]any{ + "type": "message", + "role": "user", + "content": "Respond with OK.", + }, + }, + } +} + +func shouldMarkOpenAICompactUnsupported(status int, body []byte) bool { + switch status { + case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented: + return true + case http.StatusBadRequest, http.StatusForbidden, http.StatusUnprocessableEntity: + lower := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body) + " " + string(body))) + if strings.Contains(lower, "compact") { + for _, keyword := range []string{ + "unsupported", + "not support", + "does not support", + "not available", + "disabled", + } { + if strings.Contains(lower, keyword) { + return true + } + } + } + } + return false +} + +func buildOpenAICompactProbeExtraUpdates(resp *http.Response, body []byte, probeErr error, now time.Time) map[string]any { + updates := map[string]any{ + "openai_compact_checked_at": now.Format(time.RFC3339), + "openai_compact_last_status": nil, + } + + if resp != nil { + updates["openai_compact_last_status"] = resp.StatusCode + } + + switch { + case probeErr != nil: + updates["openai_compact_last_error"] = truncateString(sanitizeUpstreamErrorMessage(probeErr.Error()), 2048) + case resp == nil: + updates["openai_compact_last_error"] = "compact probe failed" + default: + errMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if errMsg == "" && len(body) > 0 { + errMsg = strings.TrimSpace(string(body)) + } + if errMsg == "" && (resp.StatusCode < 200 || resp.StatusCode >= 300) { + errMsg = "HTTP " + strconv.Itoa(resp.StatusCode) + } + errMsg = truncateString(sanitizeUpstreamErrorMessage(errMsg), 2048) + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + updates["openai_compact_supported"] = true + updates["openai_compact_last_error"] = "" + } else { + if shouldMarkOpenAICompactUnsupported(resp.StatusCode, body) { + updates["openai_compact_supported"] = false + } + updates["openai_compact_last_error"] = errMsg + } + } + + return updates +} + +func mergeExtraUpdates(base map[string]any, more map[string]any) map[string]any { + if len(base) == 0 && len(more) == 0 { + return nil + } + out := make(map[string]any, len(base)+len(more)) + for key, value := range base { + out[key] = value + } + for key, value := range more { + out[key] = value + } + return out +} + +func compactProbeSessionID(accountID int64) string { + if accountID <= 0 { + return "probe_compact" + } + return "probe_compact_" + strconv.FormatInt(accountID, 10) +} diff --git a/backend/internal/service/openai_compact_probe_test.go b/backend/internal/service/openai_compact_probe_test.go new file mode 100644 index 00000000..fe3ba0e8 --- /dev/null +++ b/backend/internal/service/openai_compact_probe_test.go @@ -0,0 +1,122 @@ +package service + +import ( + "errors" + "net/http" + "testing" + "time" +) + +func TestNormalizeAccountTestMode(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "", want: AccountTestModeDefault}, + {input: "default", want: AccountTestModeDefault}, + {input: " compact ", want: AccountTestModeCompact}, + {input: "COMPACT", want: AccountTestModeCompact}, + {input: "unknown", want: AccountTestModeDefault}, + } + + for _, tt := range tests { + if got := normalizeAccountTestMode(tt.input); got != tt.want { + t.Fatalf("normalizeAccountTestMode(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusOK}, []byte(`{"id":"cmp_1"}`), nil, now) + + if got := updates["openai_compact_supported"]; got != true { + t.Fatalf("openai_compact_supported = %v, want true", got) + } + if got := updates["openai_compact_last_status"]; got != http.StatusOK { + t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusOK) + } + if got := updates["openai_compact_last_error"]; got != "" { + t.Fatalf("openai_compact_last_error = %v, want empty string", got) + } + if got := updates["openai_compact_checked_at"]; got != now.Format(time.RFC3339) { + t.Fatalf("openai_compact_checked_at = %v, want %s", got, now.Format(time.RFC3339)) + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + body := []byte(`404 page not found`) + updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusNotFound}, body, nil, now) + + if got := updates["openai_compact_supported"]; got != false { + t.Fatalf("openai_compact_supported = %v, want false", got) + } + if got := updates["openai_compact_last_status"]; got != http.StatusNotFound { + t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusNotFound) + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadGateway}, []byte(`Upstream request failed`), nil, now) + + if _, exists := updates["openai_compact_supported"]; exists { + t.Fatalf("did not expect openai_compact_supported for 502 response") + } + if got := updates["openai_compact_last_status"]; got != http.StatusBadGateway { + t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadGateway) + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + updates := buildOpenAICompactProbeExtraUpdates(nil, nil, errors.New("dial tcp timeout"), now) + + if _, exists := updates["openai_compact_supported"]; exists { + t.Fatalf("did not expect openai_compact_supported for request error") + } + if got, exists := updates["openai_compact_last_status"]; !exists || got != nil { + t.Fatalf("openai_compact_last_status = %v, want nil key", got) + } + if got := updates["openai_compact_last_error"]; got == "" { + t.Fatalf("expected openai_compact_last_error to be populated") + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + updates := buildOpenAICompactProbeExtraUpdates(nil, nil, nil, now) + + if got, exists := updates["openai_compact_last_status"]; !exists || got != nil { + t.Fatalf("openai_compact_last_status = %v, want nil key", got) + } + if got := updates["openai_compact_last_error"]; got != "compact probe failed" { + t.Fatalf("openai_compact_last_error = %v, want compact probe failed", got) + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + body := []byte(`{"error":{"message":"unknown model gpt-5.4-openai-compact"}}`) + updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadRequest}, body, nil, now) + + if _, exists := updates["openai_compact_supported"]; exists { + t.Fatalf("did not expect openai_compact_supported for unknown-model diagnostics") + } + if got := updates["openai_compact_last_status"]; got != http.StatusBadRequest { + t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadRequest) + } +} + +func TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus(t *testing.T) { + now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC) + updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil, nil, now) + + if got := updates["openai_compact_last_status"]; got != http.StatusServiceUnavailable { + t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusServiceUnavailable) + } + if got := updates["openai_compact_last_error"]; got != "HTTP 503" { + t.Fatalf("openai_compact_last_error = %v, want HTTP 503", got) + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d99cd7da..379ebe0b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -40,7 +40,7 @@ const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL - codexCLIUserAgent = "codex_cli_rs/0.104.0" + codexCLIUserAgent = "codex_cli_rs/0.125.0" // codex_cli_only 拒绝时单个请求头日志长度上限(字符) codexCLIOnlyHeaderValueMaxBytes = 256 @@ -54,7 +54,7 @@ const ( openAIWSRetryBackoffMaxDefault = 2 * time.Second openAIWSRetryJitterRatioDefault = 0.2 openAICompactSessionSeedKey = "openai_compact_session_seed" - codexCLIVersion = "0.104.0" + codexCLIVersion = "0.125.0" // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 openAICodexSnapshotPersistMinInterval = 30 * time.Second ) @@ -306,6 +306,10 @@ func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) +// ErrNoAvailableCompactAccounts indicates the request needs /responses/compact +// support but no compatible account is available. +var ErrNoAvailableCompactAccounts = errors.New("no available OpenAI accounts support /responses/compact") + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -442,11 +446,11 @@ func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Contex return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) } -func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { +func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string, requireCompact bool) bool { if s.channelService == nil { return false } - upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "") + upstreamModel := resolveOpenAIAccountUpstreamModelForRequest(account, requestedModel, requireCompact) if upstreamModel == "" { return false } @@ -1208,10 +1212,94 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0) + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0) } -func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { +// noAvailableOpenAISelectionError builds the standard "no account available" error +// while preserving the compact-specific error when applicable. +func noAvailableOpenAISelectionError(requestedModel string, compactBlocked bool) error { + if compactBlocked { + return ErrNoAvailableCompactAccounts + } + if requestedModel != "" { + return fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel) + } + return errors.New("no available OpenAI accounts") +} + +// openAICompactSupportTier classifies an OpenAI account by compact capability. +// 0 = explicitly unsupported, 1 = unknown / not yet probed, 2 = explicitly supported. +func openAICompactSupportTier(account *Account) int { + if account == nil || !account.IsOpenAI() { + return 0 + } + supported, known := account.OpenAICompactSupportKnown() + if !known { + return 1 + } + if supported { + return 2 + } + return 0 +} + +// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model / +// compact-support checks used during account selection. +func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool { + if account == nil || !account.IsSchedulable() || !account.IsOpenAI() { + return false + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return false + } + if requireCompact && openAICompactSupportTier(account) == 0 { + return false + } + return true +} + +// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known +// compact support are tried first, followed by unknown, then explicitly unsupported. +// The relative order within each tier is preserved. +func prioritizeOpenAICompactAccounts(accounts []*Account) []*Account { + if len(accounts) == 0 { + return nil + } + supported := make([]*Account, 0, len(accounts)) + unknown := make([]*Account, 0, len(accounts)) + unsupported := make([]*Account, 0, len(accounts)) + for _, account := range accounts { + switch openAICompactSupportTier(account) { + case 2: + supported = append(supported, account) + case 1: + unknown = append(unknown, account) + default: + unsupported = append(unsupported, account) + } + } + out := make([]*Account, 0, len(accounts)) + out = append(out, supported...) + out = append(out, unknown...) + out = append(out, unsupported...) + return out +} + +// resolveOpenAIAccountUpstreamModelForRequest resolves the upstream model that +// would be sent for a given request, honouring compact-only mappings when the +// caller is on the /responses/compact path. +func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedModel string, requireCompact bool) string { + upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "") + if upstreamModel == "" { + return "" + } + if requireCompact { + return resolveOpenAICompactForwardModel(account, upstreamModel) + } + return upstreamModel +} + +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) { if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { slog.Warn("channel pricing restriction blocked request", "group_id", derefGroupID(groupID), @@ -1221,7 +1309,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil { return account, nil } @@ -1234,13 +1322,10 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs) + selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact) if selected == nil { - if requestedModel != "" { - return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel) - } - return nil, errors.New("no available OpenAI accounts") + return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked) } // 4. 设置粘性会话绑定 @@ -1257,7 +1342,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account { if sessionHash == "" { return nil } @@ -1289,19 +1374,16 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 验证账号是否可用于当前请求 // Verify account is usable for current request - if !account.IsSchedulable() || !account.IsOpenAI() { + if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) { return nil } - if requestedModel != "" && !account.IsModelSupported(requestedModel) { - return nil - } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) && - s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } @@ -1316,9 +1398,13 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 返回 nil 表示无可用账号。 // // selectBestAccount selects the best account from candidates (priority + LRU). -// Returns nil if no available account. -func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +// Returns nil if no available account. The second return reports whether at +// least one candidate was filtered out solely because it lacks compact support +// (only meaningful when requireCompact=true). +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) { var selected *Account + selectedCompactTier := -1 + compactBlocked := false needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) for i := range accounts { @@ -1330,31 +1416,50 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i continue } - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false) if fresh == nil { continue } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { continue } + compactTier := 0 + if requireCompact { + compactTier = openAICompactSupportTier(fresh) + if compactTier == 0 { + compactBlocked = true + continue + } + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used if selected == nil { selected = fresh + selectedCompactTier = compactTier + continue + } + + // compact 模式下高 tier 优先;同 tier 内才比较 priority/LRU。 + if requireCompact && compactTier != selectedCompactTier { + if compactTier > selectedCompactTier { + selected = fresh + selectedCompactTier = compactTier + } continue } if s.isBetterAccount(fresh, selected) { selected = fresh + selectedCompactTier = compactTier } } - return selected + return selected, compactBlocked } // isBetterAccount 判断 candidate 是否比 current 更优。 @@ -1392,6 +1497,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false) +} + +func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) { if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { slog.Warn("channel pricing restriction blocked request", "group_id", derefGroupID(groupID), @@ -1408,7 +1517,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID) if err != nil { return nil, err } @@ -1461,12 +1570,11 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if clearSticky { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } - if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && - (requestedModel == "" || account.IsModelSupported(requestedModel)) { - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) { + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) @@ -1491,6 +1599,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } // ============ Layer 2: Load-aware selection ============ + baseCandidateCount := 0 candidates := make([]*Account, 0, len(accounts)) for i := range accounts { acc := &accounts[i] @@ -1506,9 +1615,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) { continue } + baseCandidateCount++ candidates = append(candidates, acc) } @@ -1528,12 +1638,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if err != nil { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, false) + if requireCompact { + ordered = prioritizeOpenAICompactAccounts(ordered) + } for _, acc := range ordered { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) if fresh == nil { continue } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { continue } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) @@ -1581,12 +1698,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex }) shuffleWithinSortGroups(available) - for _, item := range available { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel) + selectionOrder := make([]accountWithLoad, 0, len(available)) + if requireCompact { + appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { + for _, item := range available { + if openAICompactSupportTier(item.account) == tier { + out = append(out, item) + } + } + return out + } + selectionOrder = appendTier(selectionOrder, 2) + selectionOrder = appendTier(selectionOrder, 1) + // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 + // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 + selectionOrder = appendTier(selectionOrder, 0) + } else { + selectionOrder = append(selectionOrder, available...) + } + + for _, item := range selectionOrder { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) if fresh == nil { continue } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { continue } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) @@ -1602,12 +1742,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 3: Fallback wait ============ sortAccountsByPriorityAndLastUsed(candidates, false) + if requireCompact { + candidates = prioritizeOpenAICompactAccounts(candidates) + } for _, acc := range candidates { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) if fresh == nil { continue } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { continue } return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{ @@ -1618,6 +1765,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex }) } + if requireCompact && baseCandidateCount > 0 { + return nil, ErrNoAvailableCompactAccounts + } return nil, ErrNoAvailableAccounts } @@ -1648,7 +1798,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account { +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { if account == nil { return nil } @@ -1662,20 +1812,20 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. fresh = current } - if !fresh.IsSchedulable() || !fresh.IsOpenAI() { - return nil - } - if requestedModel != "" && !fresh.IsModelSupported(requestedModel) { + if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) { return nil } return fresh } -func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account { +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { if account == nil { return nil } if s.schedulerSnapshot == nil || s.accountRepo == nil { + if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) { + return nil + } return account } @@ -1683,10 +1833,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co if err != nil || latest == nil { return nil } - if !latest.IsSchedulable() || !latest.IsOpenAI() { - return nil - } - if requestedModel != "" && !latest.IsModelSupported(requestedModel) { + if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { return nil } return latest @@ -1995,18 +2142,47 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco account.Type, ) } + if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "input", + }, + }) + return nil, err + } + + // Compact-only model 映射:仅在 /responses/compact 路径生效,且优先级高于 + // OAuth 模型规范化(避免 OAuth 规范化覆盖 compact-only 自定义模型)。 + isCompactRequest := isOpenAIResponsesCompactPath(c) + compactMapped := false + if isCompactRequest { + compactMappedModel := resolveOpenAICompactForwardModel(account, billingModel) + if compactMappedModel != "" && compactMappedModel != billingModel { + compactMapped = true + upstreamModel = compactMappedModel + reqBody["model"] = compactMappedModel + bodyModified = true + markPatchSet("model", compactMappedModel) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Compact model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", billingModel, compactMappedModel, account.Name, isCodexCLI) + } + } // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, // 以兼容自定义 base_url 的 OpenAI-compatible 上游。 if model, ok := reqBody["model"].(string); ok { - upstreamModel = normalizeOpenAIModelForUpstream(account, model) - if upstreamModel != "" && upstreamModel != model { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", - model, upstreamModel, account.Name, account.Type, isCodexCLI) - reqBody["model"] = upstreamModel - bodyModified = true - markPatchSet("model", upstreamModel) + if !compactMapped { + upstreamModel = normalizeOpenAIModelForUpstream(account, model) + if upstreamModel != "" && upstreamModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, upstreamModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = upstreamModel + bodyModified = true + markPatchSet("model", upstreamModel) + } } // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 @@ -2029,7 +2205,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } if account.Type == AccountTypeOAuth { - codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c)) + codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest) if codexResult.Modified { bodyModified = true disablePatch() @@ -2504,6 +2680,19 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( reqStream bool, startTime time.Time, ) (*OpenAIForwardResult, error) { + upstreamPassthroughModel := "" + if isOpenAIResponsesCompactPath(c) { + compactMappedModel := resolveOpenAICompactForwardModel(account, reqModel) + if compactMappedModel != "" && compactMappedModel != reqModel { + nextBody, setErr := sjson.SetBytes(body, "model", compactMappedModel) + if setErr != nil { + return nil, fmt.Errorf("set compact passthrough model: %w", setErr) + } + body = nextBody + upstreamPassthroughModel = compactMappedModel + } + } + if account != nil && account.Type == AccountTypeOAuth { if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" @@ -2629,14 +2818,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( var usage *OpenAIUsage var firstTokenMs *int if reqStream { - result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime) + result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) if err != nil { return nil, err } usage = result.usage firstTokenMs = result.firstTokenMs } else { - usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c) + usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) if err != nil { return nil, err } @@ -2654,6 +2843,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: reqModel, + UpstreamModel: upstreamPassthroughModel, ServiceTier: extractOpenAIServiceTierFromBody(body), ReasoningEffort: reasoningEffort, Stream: reqStream, @@ -2957,12 +3147,121 @@ type openaiStreamingResultPassthrough struct { firstTokenMs *int } +func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { + if localStarted { + return true + } + return c != nil && c.Writer != nil && c.Writer.Written() +} + +func openAIStreamEventIsPreamble(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.created", "response.in_progress": + return true + default: + return false + } +} + +func openAIStreamDataStartsClientOutput(data, eventType string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if strings.TrimSpace(eventType) == "response.failed" { + return false + } + return !openAIStreamEventIsPreamble(eventType) +} + +func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool { + code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String())) + if code == "" { + code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String())) + } + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String())) + if errType == "" { + errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String())) + } + combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType)) + if combined == "" { + return true + } + nonRetryableMarkers := []string{ + "invalid_request", + "content_policy", + "policy", + "safety", + "high-risk cyber", + "not allowed", + "violat", + } + for _, marker := range nonRetryableMarkers { + if strings.Contains(combined, marker) { + return false + } + } + return true +} + +func (s *OpenAIGatewayService) newOpenAIStreamFailoverError( + c *gin.Context, + account *Account, + passthrough bool, + upstreamRequestID string, + payload []byte, + message string, +) *UpstreamFailoverError { + message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "OpenAI stream disconnected before completion" + } + detail := "" + if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + detail = truncateString(string(payload), maxBytes) + } + if c != nil { + setOpsUpstreamError(c, http.StatusBadGateway, message, detail) + event := OpsUpstreamErrorEvent{ + Platform: PlatformOpenAI, + UpstreamStatusCode: http.StatusBadGateway, + UpstreamRequestID: strings.TrimSpace(upstreamRequestID), + Passthrough: passthrough, + Kind: "failover", + Message: message, + Detail: detail, + } + if account != nil { + event.Platform = account.Platform + event.AccountID = account.ID + event.AccountName = account.Name + } + appendOpsUpstreamError(c, event) + } + body, _ := json.Marshal(gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": message, + }, + }) + return &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: body, + } +} + func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, + originalModel string, + mappedModel string, ) (*openaiStreamingResultPassthrough, error) { writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) @@ -2986,7 +3285,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( clientDisconnected := false sawDone := false sawTerminalEvent := false + sawFailedEvent := false + failedMessage := "" + clientOutputStarted := false upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + pendingLines := make([]string, 0, 8) + writePendingLines := func() bool { + for _, pending := range pendingLines { + if _, err := fmt.Fprintln(w, pending); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + return false + } + } + pendingLines = pendingLines[:0] + return true + } scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -2997,18 +3311,40 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( scanner.Buffer(scanBuf[:0], maxLineSize) defer putSSEScannerBuf64K(scanBuf) + needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel) + for scanner.Scan() { line := scanner.Text() + lineStartsClientOutput := false + forceFlushFailedEvent := false if data, ok := extractOpenAISSEDataLine(line); ok { dataBytes := []byte(data) trimmedData := strings.TrimSpace(data) + if needModelReplace && strings.Contains(data, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + if replacedData, replaced := extractOpenAISSEDataLine(line); replaced { + dataBytes = []byte(replacedData) + trimmedData = strings.TrimSpace(replacedData) + } + } + eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String()) + if eventType == "response.failed" { + failedMessage = extractOpenAISSEErrorMessage(dataBytes) + if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage) + } + forceFlushFailedEvent = true + sawFailedEvent = true + } if trimmedData == "[DONE]" { sawDone = true } if openAIStreamEventIsTerminal(trimmedData) { sawTerminalEvent = true } - if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { + lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType) + if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } @@ -3016,20 +3352,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } if !clientDisconnected { + if !clientOutputStarted && !lineStartsClientOutput { + pendingLines = append(pendingLines, line) + continue + } + if !clientOutputStarted && len(pendingLines) > 0 { + if !writePendingLines() { + continue + } + } if _, err := fmt.Fprintln(w, line); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) } else { + clientOutputStarted = true flusher.Flush() } } } if err := scanner.Err(); err != nil { - if sawTerminalEvent { + if sawTerminalEvent && !sawFailedEvent { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } - if clientDisconnected { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + if sawFailedEvent { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) @@ -3038,6 +3384,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err } + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + msg := "OpenAI stream disconnected before completion" + if errText := strings.TrimSpace(err.Error()); errText != "" { + msg += ": " + errText + } + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg) + } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", account.ID, @@ -3046,12 +3403,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } + if sawFailedEvent { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) + } if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( zap.String("component", "service.openai_gateway"), zap.Int64("account_id", account.ID), zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event") + } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") } @@ -3062,6 +3426,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( ctx context.Context, resp *http.Response, c *gin.Context, + originalModel string, + mappedModel string, ) (*OpenAIUsage, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { @@ -3073,7 +3439,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( // stream=false was requested. Without this conversion the client would // receive raw SSE text or a terminal event with empty output. if isEventStreamResponse(resp.Header) { - return s.handlePassthroughSSEToJSON(resp, c, body) + return s.handlePassthroughSSEToJSON(resp, c, body, originalModel, mappedModel) } usage := &OpenAIUsage{} @@ -3095,14 +3461,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( if contentType == "" { contentType = "application/json" } + if originalModel != "" && mappedModel != "" && originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } c.Data(resp.StatusCode, contentType, body) return usage, nil } // handlePassthroughSSEToJSON converts an SSE response body into a JSON -// response for the passthrough path. It mirrors handleSSEToJSON but skips -// model replacement (passthrough does not remap models). -func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte) (*OpenAIUsage, error) { +// response for the passthrough path. It mirrors handleSSEToJSON while +// preserving passthrough payloads, except compact-only model remapping may +// rewrite model fields back to the original requested model. +func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) { bodyText := string(body) finalResponse, ok := extractCodexFinalResponse(bodyText) @@ -3121,6 +3491,9 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c } } body = finalResponse + if originalModel != "" && mappedModel != "" && originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } // Correct tool calls in final response body = s.correctToolCallsInResponseBody(body) } else { @@ -3133,6 +3506,10 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) } usage = s.parseSSEUsageFromBody(bodyText) + if originalModel != "" && mappedModel != "" && originalModel != mappedModel { + bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) + } + body = []byte(bodyText) } writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) @@ -3631,8 +4008,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if keepaliveTicker != nil { keepaliveCh = keepaliveTicker.C } - // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 - lastDataAt := time.Now() + // Track downstream writes separately from upstream reads: pre-output failover + // can buffer response.created / response.in_progress, so keepalive must be + // based on downstream idle time. + lastDownstreamWriteAt := time.Now() // 仅发送一次错误事件,避免多次写入导致协议混乱。 // 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema; @@ -3640,6 +4019,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp errorEventSent := false clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage sawTerminalEvent := false + sawFailedEvent := false + failedMessage := "" + clientOutputStarted := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + var streamFailoverErr error sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return @@ -3656,7 +4040,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } if err := flushBuffered(); err != nil { clientDisconnected = true + return } + clientOutputStarted = true + lastDownstreamWriteAt = time.Now() } needModelReplace := originalModel != mappedModel @@ -3664,45 +4051,73 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} } finalizeStream := func() (*openaiStreamingResult, error) { + if !sawTerminalEvent { + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + return resultWithUsage(), s.newOpenAIStreamFailoverError( + c, + account, + false, + upstreamRequestID, + nil, + "OpenAI stream ended before a terminal event", + ) + } + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } + if sawFailedEvent { + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) + } if !clientDisconnected { + hadBufferedData := bufferedWriter.Buffered() > 0 if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") + } else if hadBufferedData { + clientOutputStarted = true + lastDownstreamWriteAt = time.Now() } } - if !sawTerminalEvent { - return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") - } return resultWithUsage(), nil } handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { if scanErr == nil { return nil, nil, false } - if sawTerminalEvent { + if sawTerminalEvent && !sawFailedEvent { logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) return resultWithUsage(), nil, true } + if sawFailedEvent { + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true + } // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true } - // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage - if clientDisconnected { - return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true - } if errors.Is(scanErr, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) sendErrorEvent("response_too_large") return resultWithUsage(), scanErr, true } + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + msg := "OpenAI stream disconnected before completion" + if errText := strings.TrimSpace(scanErr.Error()); errText != "" { + msg += ": " + errText + } + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true + } sendErrorEvent("stream_read_error") return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true } processSSELine := func(line string, queueDrained bool) { - lastDataAt = time.Now() - + if streamFailoverErr != nil { + return + } // Extract data from SSE line (supports both "data: " and "data:" formats) if data, ok := extractOpenAISSEDataLine(line); ok { @@ -3716,18 +4131,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if openAIStreamEventIsTerminal(data) { sawTerminalEvent = true } + eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String()) + forceFlushFailedEvent := false + if eventType == "response.failed" { + failedMessage = extractOpenAISSEErrorMessage(dataBytes) + if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { + sawFailedEvent = true + streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage) + return + } + forceFlushFailedEvent = true + sawFailedEvent = true + } // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { dataBytes = correctedData data = string(correctedData) line = "data: " + data + eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String()) } + startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType) // 写入客户端(客户端断开后继续 drain 上游) if !clientDisconnected { - shouldFlush := queueDrained - if firstTokenMs == nil && data != "" && data != "[DONE]" { + shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput) + if firstTokenMs == nil && startsClientOutput { // 保证首个 token 事件尽快出站,避免影响 TTFT。 shouldFlush = true } @@ -3741,12 +4170,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } else { + clientOutputStarted = true + lastDownstreamWriteAt = time.Now() } } } // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { + if firstTokenMs == nil && startsClientOutput { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } @@ -3762,10 +4194,13 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } else if _, err := bufferedWriter.WriteString("\n"); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - } else if queueDrained { + } else if queueDrained && clientOutputStarted { if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } else { + clientOutputStarted = true + lastDownstreamWriteAt = time.Now() } } } @@ -3776,6 +4211,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp defer putSSEScannerBuf64K(scanBuf) for scanner.Scan() { processSSELine(scanner.Text(), true) + if streamFailoverErr != nil { + return resultWithUsage(), streamFailoverErr + } } if result, err, done := handleScanErr(scanner.Err()); done { return result, err @@ -3825,6 +4263,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return result, err } processSSELine(ev.line, len(events) == 0) + if streamFailoverErr != nil { + return resultWithUsage(), streamFailoverErr + } case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -3846,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if clientDisconnected { continue } - if time.Since(lastDataAt) < keepaliveInterval { + if time.Since(lastDownstreamWriteAt) < keepaliveInterval { continue } if _, err := bufferedWriter.WriteString(":\n\n"); err != nil { @@ -3857,6 +4298,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing") + } else { + lastDownstreamWriteAt = time.Now() } } } @@ -3935,7 +4378,8 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag return } eventType := gjson.GetBytes(data, "type").String() - if eventType != "response.completed" && eventType != "response.done" { + if eventType != "response.completed" && eventType != "response.done" && + eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" { return } @@ -4082,7 +4526,7 @@ func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { } eventType := strings.TrimSpace(gjson.Get(data, "type").String()) switch eventType { - case "response.completed", "response.done", "response.failed": + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": return eventType, []byte(data), true } } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index ed7c78a3..bc900689 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -93,6 +93,13 @@ type cancelReadCloser struct{} func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled } func (c cancelReadCloser) Close() error { return nil } +type errReadCloser struct { + err error +} + +func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err } +func (r errReadCloser) Close() error { return nil } + type failingGinWriter struct { gin.ResponseWriter failAfter int @@ -1003,6 +1010,190 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr } } +func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: errReadCloser{err: io.ErrUnexpectedEOF}, + Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + +func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.in_progress", + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-failed"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request") + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + +func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.in_progress", + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + +func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 1, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n")) + for i := 0; i < 6; i++ { + time.Sleep(250 * time.Millisecond) + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n")) + } + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), ":\n\n") + require.Contains(t, rec.Body.String(), "response.completed") +} + +func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.False(t, errors.As(err, &failoverErr)) + require.True(t, c.Writer.Written()) + require.Contains(t, rec.Body.String(), "response.failed") + require.Contains(t, rec.Body.String(), "high-risk cyber activity") +} + func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1072,7 +1263,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") @@ -1104,16 +1295,52 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n")) }() - _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "") _ = pr.Close() if err == nil || !strings.Contains(err.Error(), "missing terminal event") { t.Fatalf("expected missing terminal event error, got %v", err) } } +func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}}, + } + + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed") + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1139,7 +1366,42 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t _, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) }() - result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 2, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 1, result.usage.CacheReadInputTokens) +} + +func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "") _ = pr.Close() require.NoError(t, err) require.NotNil(t, result) diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index 9bf3fba3..f332633c 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -1,5 +1,7 @@ package service +import "strings" + // resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible // forwarding. Group-level default mapping only applies when the account itself // did not match any explicit model_mapping rule. @@ -12,8 +14,47 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo } mappedModel, matched := account.ResolveMappedModel(requestedModel) - if !matched && defaultMappedModel != "" { + if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) { return defaultMappedModel } return mappedModel } + +func isExplicitCodexModel(model string) bool { + model = strings.TrimSpace(model) + if model == "" { + return false + } + if strings.Contains(model, "/") { + parts := strings.Split(model, "/") + model = parts[len(parts)-1] + } + model = strings.ToLower(strings.TrimSpace(model)) + if getNormalizedCodexModel(model) != "" { + return true + } + if strings.HasSuffix(model, "-openai-compact") { + base := strings.TrimSuffix(model, "-openai-compact") + return getNormalizedCodexModel(base) != "" + } + return false +} + +// resolveOpenAICompactForwardModel determines the compact-only upstream model +// for /responses/compact requests. It never affects normal /responses traffic. +// When no compact-specific mapping matches, the input model is returned as-is. +func resolveOpenAICompactForwardModel(account *Account, model string) string { + trimmedModel := strings.TrimSpace(model) + if trimmedModel == "" || account == nil { + return trimmedModel + } + + mappedModel, matched := account.ResolveCompactMappedModel(trimmedModel) + if !matched { + return trimmedModel + } + if trimmedMapped := strings.TrimSpace(mappedModel); trimmedMapped != "" { + return trimmedMapped + } + return trimmedModel +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index f25863a8..4802c089 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) { account: &Account{ Credentials: map[string]any{}, }, - requestedModel: "gpt-5.4", + requestedModel: "claude-opus-4-6", defaultMappedModel: "gpt-4o-mini", expectedModel: "gpt-4o-mini", }, + { + name: "preserves explicit gpt-5.4 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, { name: "preserves exact passthrough mapping instead of group default", account: &Account{ @@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) { defaultMappedModel: "gpt-4o-mini", expectedModel: "gpt-5.4", }, + { + name: "preserves codex spark instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.3-codex-spark", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt-5.3-codex-spark", + }, + { + name: "preserves gpt-5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.5", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt-5.5", + }, + { + name: "preserves openai namespaced gpt-5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "openai/gpt-5.5", + defaultMappedModel: "gpt-5.4", + expectedModel: "openai/gpt-5.5", + }, + { + name: "preserves compact gpt-5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.5-openai-compact", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt-5.5-openai-compact", + }, } for _, tt := range tests { @@ -85,6 +130,74 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t * } } +func TestResolveOpenAICompactForwardModel(t *testing.T) { + tests := []struct { + name string + account *Account + model string + expectedModel string + }{ + { + name: "nil account keeps original model", + account: nil, + model: "gpt-5.4", + expectedModel: "gpt-5.4", + }, + { + name: "missing compact mapping keeps original model", + account: &Account{ + Credentials: map[string]any{}, + }, + model: "gpt-5.4", + expectedModel: "gpt-5.4", + }, + { + name: "exact compact mapping overrides model", + account: &Account{ + Credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4-openai-compact", + }, + }, + }, + model: "gpt-5.4", + expectedModel: "gpt-5.4-openai-compact", + }, + { + name: "wildcard compact mapping overrides model", + account: &Account{ + Credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.*": "gpt-5-openai-compact", + }, + }, + }, + model: "gpt-5.4", + expectedModel: "gpt-5-openai-compact", + }, + { + name: "passthrough compact mapping remains unchanged", + account: &Account{ + Credentials: map[string]any{ + "compact_model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + }, + model: "gpt-5.4", + expectedModel: "gpt-5.4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveOpenAICompactForwardModel(tt.account, tt.model); got != tt.expectedModel { + t.Fatalf("resolveOpenAICompactForwardModel(...) = %q, want %q", got, tt.expectedModel) + } + }) + } +} + func TestNormalizeCodexModel(t *testing.T) { cases := map[string]string{ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 69c9de42..049ffdd8 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -734,7 +734,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te require.NoError(t, err) require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) - require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent")) + require.Equal(t, "codex_cli_rs/0.125.0", upstream.lastReq.Header.Get("User-Agent")) } func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) { diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index dea3c172..c0f98de4 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct { } // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 -// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 +// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、 // 或显式声明 tools/tool_choice。 func NeedsToolContinuation(reqBody map[string]any) bool { if reqBody == nil { @@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool { continue } itemType, _ := itemMap["type"].(string) - if itemType == "function_call_output" || itemType == "item_reference" { + if isCodexToolCallItemType(itemType) || itemType == "item_reference" { return true } } diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go index fe737ad6..3f415d9d 100644 --- a/backend/internal/service/openai_tool_continuation_test.go +++ b/backend/internal/service/openai_tool_continuation_test.go @@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) { {name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true}, {name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false}, {name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true}, + {name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true}, + {name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true}, + {name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true}, {name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true}, {name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true}, {name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false}, diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index a5b97ca9..4005a921 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -37,7 +37,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour)) - selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil) + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil, false) require.NoError(t, err) require.NotNil(t, selection) require.NotNil(t, selection.Account) @@ -77,7 +77,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss( require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour)) - selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil) + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil, false) require.NoError(t, err) require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连") boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl") @@ -129,7 +129,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheck require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour)) - selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil) + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil, false) require.NoError(t, err) require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连") boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl") @@ -164,7 +164,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *test require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour)) - selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}) + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}, false) require.NoError(t, err) require.Nil(t, selection) } @@ -197,7 +197,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour)) - selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil) + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil, false) require.NoError(t, err) require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连") } @@ -258,7 +258,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky( require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour)) - selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil) + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil, false) require.NoError(t, err) require.NotNil(t, selection) require.NotNil(t, selection.Account) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 83849bf3..8c0222e2 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -3800,6 +3800,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( previousResponseID string, requestedModel string, excludedIDs map[int64]struct{}, + requireCompact bool, ) (*AccountSelectionResult, error) { if s == nil { return nil, nil @@ -3840,11 +3841,16 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil } + // 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。 + if requireCompact && openAICompactSupportTier(account) == 0 { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 243edff3..c6167447 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -2,6 +2,7 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -268,6 +269,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e switch action { case redeemActionSkipCompleted: + s.applyAffiliateRebateForOrder(ctx, o) // Code already created and redeemed — just mark completed return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") case redeemActionCreate: @@ -281,6 +283,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) } + s.applyAffiliateRebateForOrder(ctx, o) return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") } @@ -358,6 +361,139 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action return c > 0 } +func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) { + if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 { + return + } + if s.affiliateService == nil { + return + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": fmt.Sprintf("begin affiliate rebate tx: %v", err), + }) + return + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount) + if err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + if !claimed { + return + } + + rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) + if err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + + if rebateAmount <= 0 { + if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{ + "baseAmount": o.Amount, + "reason": "no inviter bound or rebate amount <= 0", + }); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + if err := tx.Commit(); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), + }) + } + return + } + + if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{ + "baseAmount": o.Amount, + "rebateAmount": rebateAmount, + }); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + + if err := tx.Commit(); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), + }) + } +} + +func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) { + if client == nil { + return false, errors.New("nil payment client") + } + oid := strconv.FormatInt(orderID, 10) + detail, _ := json.Marshal(map[string]any{ + "baseAmount": baseAmount, + "status": "reserved", + }) + rows, err := client.QueryContext(ctx, ` +INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) +SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW() +WHERE NOT EXISTS ( + SELECT 1 + FROM payment_audit_logs + WHERE order_id = $1 + AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') +) +ON CONFLICT (order_id, action) DO NOTHING +RETURNING id`, oid, string(detail)) + if err != nil { + return false, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return false, err + } + return false, nil + } + var claimID int64 + if err := rows.Scan(&claimID); err != nil { + return false, err + } + return true, nil +} + +func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error { + if client == nil { + return errors.New("nil payment client") + } + oid := strconv.FormatInt(orderID, 10) + detailJSON, _ := json.Marshal(detail) + updated, err := client.PaymentAuditLog.Update(). + Where( + paymentauditlog.OrderIDEQ(oid), + paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"), + ). + SetAction(action). + SetDetail(string(detailJSON)). + SetOperator("system"). + Save(ctx) + if err != nil { + return err + } + if updated == 0 { + return errors.New("affiliate rebate claim log not found") + } + return nil +} + func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) { now := time.Now() r := psErrMsg(cause) diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 97fd76a0..aa121e41 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -170,21 +170,22 @@ type TopUserStat struct { // --- Service --- type PaymentService struct { - providerMu sync.Mutex - providersLoaded bool - entClient *dbent.Client - registry *payment.Registry - loadBalancer payment.LoadBalancer - redeemService *RedeemService - subscriptionSvc *SubscriptionService - configService *PaymentConfigService - userRepo UserRepository - groupRepo GroupRepository - resumeService *PaymentResumeService + providerMu sync.Mutex + providersLoaded bool + entClient *dbent.Client + registry *payment.Registry + loadBalancer payment.LoadBalancer + redeemService *RedeemService + subscriptionSvc *SubscriptionService + configService *PaymentConfigService + userRepo UserRepository + groupRepo GroupRepository + resumeService *PaymentResumeService + affiliateService *AffiliateService } -func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} +func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService { + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService} svc.resumeService = psNewPaymentResumeService(configService) return svc } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 4730303f..9344de47 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间 // 返回 nil 表示无法从响应头中确定重置时间 -func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time { +func calculateOpenAI429ResetTime(headers http.Header) *time.Time { snapshot := ParseCodexRateLimitHeaders(headers) if snapshot == nil { return nil @@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim return nil } +func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time { + return calculateOpenAI429ResetTime(headers) +} + // anthropic429Result holds the parsed Anthropic 429 rate-limit information. type anthropic429Result struct { resetAt time.Time // The correct reset time to use for SetRateLimited diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index c79d8949..f871ee85 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log/slog" + "math" "net/url" "sort" "strconv" @@ -453,6 +454,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyChannelMonitorEnabled, SettingKeyChannelMonitorDefaultIntervalSeconds, SettingKeyAvailableChannelsEnabled, + SettingKeyAffiliateEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -540,6 +542,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]), AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true", + + AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true", }, nil } @@ -686,6 +690,7 @@ type PublicSettingsInjectionPayload struct { ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` AvailableChannelsEnabled bool `json:"available_channels_enabled"` + AffiliateEnabled bool `json:"affiliate_enabled"` } // GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection. @@ -738,6 +743,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ChannelMonitorEnabled: settings.ChannelMonitorEnabled, ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: settings.AvailableChannelsEnabled, + AffiliateEnabled: settings.AffiliateEnabled, }, nil } @@ -1167,6 +1173,8 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate) + updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64) updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { @@ -1202,6 +1210,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting // Available channels feature switch updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled) + // Affiliate (邀请返利) feature switch + updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled) + // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion @@ -1477,6 +1488,30 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool { return value == "true" } +// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关) +func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool { + value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled) + if err != nil { + return false // 默认关闭 + } + return value == "true" +} + +// GetAffiliateRebateRatePercent 读取并 clamp 全局返利比例。 +// 解析失败、缺失或越界都回退到 AffiliateRebateRateDefault — 该比例从不抛错, +// 调用方只关心一个可用的数值。 +func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) float64 { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate) + if err != nil { + return AffiliateRebateRateDefault + } + rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64) + if err != nil || math.IsNaN(rate) || math.IsInf(rate, 0) { + return AffiliateRebateRateDefault + } + return clampAffiliateRebateRate(rate) +} + // IsPasswordResetEnabled 检查是否启用密码重置功能 // 要求:必须同时开启邮件验证 func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool { @@ -1719,6 +1754,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOIDCConnectUserInfoUsernamePath: "", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64), SettingKeyDefaultUserRPMLimit: "0", SettingKeyDefaultSubscriptions: "[]", SettingKeyAuthSourceDefaultEmailBalance: "0", @@ -1767,6 +1803,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // Available channels feature (default disabled; opt-in) SettingKeyAvailableChannelsEnabled: "false", + // Affiliate (邀请返利) feature (default disabled; opt-in) + SettingKeyAffiliateEnabled: "false", + // Claude Code version check (default: empty = disabled) SettingKeyMinClaudeCodeVersion: "", SettingKeyMaxClaudeCodeVersion: "", @@ -1846,6 +1885,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.DefaultBalance = s.cfg.Default.UserBalance } + if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil { + result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate) + } else { + result.AffiliateRebateRate = AffiliateRebateRateDefault + } result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 @@ -2082,6 +2126,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Available channels feature (default: disabled; strict true) result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true" + // Affiliate (邀请返利) feature (default: disabled; strict true) + result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true" + // Claude Code version check result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion] @@ -2130,6 +2177,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin return result } +func clampAffiliateRebateRate(value float64) float64 { + if math.IsNaN(value) || math.IsInf(value, 0) { + return AffiliateRebateRateDefault + } + if value < AffiliateRebateRateMin { + return AffiliateRebateRateMin + } + if value > AffiliateRebateRateMax { + return AffiliateRebateRateMax + } + return value +} + func isFalseSettingValue(value string) bool { switch strings.ToLower(strings.TrimSpace(value)) { case "false", "0", "off", "disabled": diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ddd4fff6..70d8efc3 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -106,6 +106,8 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + AffiliateEnabled bool + AffiliateRebateRate float64 DefaultUserRPMLimit int DefaultSubscriptions []DefaultSubscriptionSetting @@ -224,6 +226,9 @@ type PublicSettings struct { // Available Channels feature (user-facing aggregate view) AvailableChannelsEnabled bool `json:"available_channels_enabled"` + + // Affiliate (邀请返利) feature toggle + AffiliateEnabled bool `json:"affiliate_enabled"` } type WeChatConnectOAuthConfig struct { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index e57c2a80..8af5c693 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -489,6 +489,7 @@ var ProviderSet = wire.NewSet( NewGroupCapacityService, NewChannelService, NewModelPricingResolver, + NewAffiliateService, ProvidePaymentConfigService, NewPaymentService, ProvidePaymentOrderExpiryService, diff --git a/backend/migrations/130_add_user_affiliates.sql b/backend/migrations/130_add_user_affiliates.sql new file mode 100644 index 00000000..d8c001e0 --- /dev/null +++ b/backend/migrations/130_add_user_affiliates.sql @@ -0,0 +1,20 @@ +CREATE TABLE IF NOT EXISTS user_affiliates ( + user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + aff_code VARCHAR(32) NOT NULL UNIQUE, + inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + aff_count INTEGER NOT NULL DEFAULT 0, + aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0, + aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id); +CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota); + +COMMENT ON TABLE user_affiliates IS '用户邀请返利信息'; +COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码'; +COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID'; +COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数'; +COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额'; +COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额'; diff --git a/backend/migrations/131_affiliate_rebate_hardening.sql b/backend/migrations/131_affiliate_rebate_hardening.sql new file mode 100644 index 00000000..81e37a9e --- /dev/null +++ b/backend/migrations/131_affiliate_rebate_hardening.sql @@ -0,0 +1,58 @@ +-- 1) Normalize historical affiliate rebate rate values. +-- Legacy compatibility treated 0 20%). +-- We now use pure percentage semantics, so convert persisted fractional values once. +UPDATE settings +SET value = to_char((value::numeric * 100), 'FM999999990.########'), + updated_at = NOW() +WHERE key = 'affiliate_rebate_rate' + AND value ~ '^-?[0-9]+(\\.[0-9]+)?$' + AND value::numeric > 0 + AND value::numeric <= 1; + +-- 2) Affiliate ledger for accrual/transfer traceability. +CREATE TABLE IF NOT EXISTS user_affiliate_ledger ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + action VARCHAR(32) NOT NULL, + amount DECIMAL(20,8) NOT NULL, + source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id); +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action); + +COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)'; +COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer'; + +-- 3) Enforce idempotency at DB layer for payment audit actions. +WITH ranked AS ( + SELECT id, + ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn + FROM payment_audit_logs +) +DELETE FROM payment_audit_logs p +USING ranked r +WHERE p.id = r.id + AND r.rn > 1; + +CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq +ON payment_audit_logs(order_id, action); + +-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders. +INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) +SELECT po.id::text, + 'AFFILIATE_REBATE_SKIPPED', + '{"reason":"baseline before affiliate rebate idempotency rollout"}', + 'system', + NOW() +FROM payment_orders po +WHERE po.order_type = 'balance' + AND po.status = 'COMPLETED' + AND NOT EXISTS ( + SELECT 1 + FROM payment_audit_logs pal + WHERE pal.order_id = po.id::text + AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') + ); diff --git a/backend/migrations/132_affiliate_custom_settings.sql b/backend/migrations/132_affiliate_custom_settings.sql new file mode 100644 index 00000000..840fe8e0 --- /dev/null +++ b/backend/migrations/132_affiliate_custom_settings.sql @@ -0,0 +1,16 @@ +-- 邀请返利:用户专属配置增强 +-- 1) aff_rebate_rate_percent: 用户作为邀请人时的专属返利比例(百分比,NULL 表示沿用全局比例) +-- 2) aff_code_custom: 标记当前 aff_code 是否被管理员手动改写过(用于"专属用户"列表筛选) + +ALTER TABLE user_affiliates + ADD COLUMN IF NOT EXISTS aff_rebate_rate_percent DECIMAL(5,2); + +ALTER TABLE user_affiliates + ADD COLUMN IF NOT EXISTS aff_code_custom BOOLEAN NOT NULL DEFAULT false; + +CREATE INDEX IF NOT EXISTS idx_user_affiliates_admin_settings + ON user_affiliates (updated_at) + WHERE aff_code_custom = true OR aff_rebate_rate_percent IS NOT NULL; + +COMMENT ON COLUMN user_affiliates.aff_rebate_rate_percent IS '专属返利比例(百分比 0-100,NULL 表示沿用全局)'; +COMMENT ON COLUMN user_affiliates.aff_code_custom IS '邀请码是否由管理员改写过(用于专属用户筛选)'; diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts new file mode 100644 index 00000000..22639bd2 --- /dev/null +++ b/frontend/src/api/admin/affiliates.ts @@ -0,0 +1,108 @@ +/** + * Admin Affiliate API endpoints + * Manage per-user affiliate (邀请返利) configurations: + * exclusive invite codes (overrides aff_code) and exclusive rebate rates. + */ + +import { apiClient } from '../client' +import type { PaginatedResponse } from '@/types' + +export interface AffiliateAdminEntry { + user_id: number + email: string + username: string + aff_code: string + aff_code_custom: boolean + aff_rebate_rate_percent?: number | null + aff_count: number +} + +export interface ListAffiliateUsersParams { + page?: number + page_size?: number + search?: string +} + +export interface UpdateAffiliateUserRequest { + aff_code?: string + aff_rebate_rate_percent?: number | null + /** Set true to explicitly clear the per-user rate (sets it to NULL). */ + clear_rebate_rate?: boolean +} + +export interface BatchSetRateRequest { + user_ids: number[] + aff_rebate_rate_percent?: number | null + /** Set true to clear rates instead of setting. */ + clear?: boolean +} + +export interface SimpleUser { + id: number + email: string + username: string +} + +export async function listUsers( + params: ListAffiliateUsersParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/users', + { + params: { + page: params.page ?? 1, + page_size: params.page_size ?? 20, + search: params.search ?? '', + }, + }, + ) + return data +} + +export async function lookupUsers(q: string): Promise { + const { data } = await apiClient.get( + '/admin/affiliates/users/lookup', + { params: { q } }, + ) + return data +} + +export async function updateUserSettings( + userId: number, + payload: UpdateAffiliateUserRequest, +): Promise<{ user_id: number }> { + const { data } = await apiClient.put<{ user_id: number }>( + `/admin/affiliates/users/${userId}`, + payload, + ) + return data +} + +export async function clearUserSettings( + userId: number, +): Promise<{ user_id: number }> { + const { data } = await apiClient.delete<{ user_id: number }>( + `/admin/affiliates/users/${userId}`, + ) + return data +} + +export async function batchSetRate( + payload: BatchSetRateRequest, +): Promise<{ affected: number }> { + const { data } = await apiClient.post<{ affected: number }>( + '/admin/affiliates/users/batch-rate', + payload, + ) + return data +} + +export const affiliatesAPI = { + listUsers, + lookupUsers, + updateUserSettings, + clearUserSettings, + batchSetRate, +} + +export default affiliatesAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index a2a313cb..639e3be2 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -30,6 +30,7 @@ import channelMonitorAPI from './channelMonitor' import channelMonitorTemplateAPI from './channelMonitorTemplate' import adminPaymentAPI from './payment' import windsurfAPI from './windsurf' +import affiliatesAPI from './affiliates' /** * Unified admin API object for convenient access @@ -61,7 +62,8 @@ export const adminAPI = { channelMonitor: channelMonitorAPI, channelMonitorTemplate: channelMonitorTemplateAPI, payment: adminPaymentAPI, - windsurf: windsurfAPI + windsurf: windsurfAPI, + affiliates: affiliatesAPI } export { @@ -91,7 +93,8 @@ export { channelMonitorAPI, channelMonitorTemplateAPI, adminPaymentAPI, - windsurfAPI + windsurfAPI, + affiliatesAPI } export default adminAPI diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 32e48fc2..cf8626fc 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -308,6 +308,7 @@ export interface SystemSettings { totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 // Default settings default_balance: number; + affiliate_rebate_rate: number; default_concurrency: number; default_user_rpm_limit: number; default_subscriptions: DefaultSubscriptionSetting[]; @@ -478,6 +479,9 @@ export interface SystemSettings { // Available Channels feature switch available_channels_enabled: boolean; + + // Affiliate (邀请返利) feature switch + affiliate_enabled: boolean; } export interface UpdateSettingsRequest { @@ -490,6 +494,7 @@ export interface UpdateSettingsRequest { invitation_code_enabled?: boolean; totp_enabled?: boolean; // TOTP 双因素认证 default_balance?: number; + affiliate_rebate_rate?: number; default_concurrency?: number; default_user_rpm_limit?: number; default_subscriptions?: DefaultSubscriptionSetting[]; @@ -636,6 +641,9 @@ export interface UpdateSettingsRequest { // Available Channels feature switch available_channels_enabled?: boolean; + + // Affiliate (邀请返利) feature switch + affiliate_enabled?: boolean; } /** diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index fd3cedb9..da7a91eb 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -9,7 +9,14 @@ import { prepareOAuthBindAccessTokenCookie, type WeChatOAuthPublicSettings, } from './auth' -import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types' +import type { + User, + ChangePasswordRequest, + NotifyEmailEntry, + UserAuthProvider, + UserAffiliateDetail, + AffiliateTransferResponse +} from '@/types' /** * Get current user profile @@ -168,6 +175,16 @@ export async function startOAuthBinding( window.location.href = startURL } +export async function getAffiliateDetail(): Promise { + const { data } = await apiClient.get('/user/aff') + return data +} + +export async function transferAffiliateQuota(): Promise { + const { data } = await apiClient.post('/user/aff/transfer') + return data +} + export const userAPI = { getProfile, updateProfile, @@ -180,7 +197,9 @@ export const userAPI = { bindEmailIdentity, unbindAuthIdentity, buildOAuthBindingStartURL, - startOAuthBinding + startOAuthBinding, + getAffiliateDetail, + transferAffiliateQuota } export default userAPI diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index ae0fd9a7..be9cba48 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -55,6 +55,17 @@ /> +
+ +