chore: merge upstream v0.1.119-121, keep Windsurf/Antigravity customizations
Upstream changes merged: - fix(scheduler): resolve SetSnapshot race conditions with Lua CAS script - fix: improve sticky session scheduling (debug logs + layer 1.5 checks) - feat: Anthropic cache TTL injection toggle - fix(gateway): stream EOF failover + sanitize stream errors - feat(httputil): zstd/gzip/deflate request decompression + bomb guard - feat(openai): OpenAI Fast/Flex Policy (HTTP + WebSocket + Admin) - feat(vertex): Vertex Service Account support - feat: account bulk edit scope and compact settings - feat(affiliate): rebate freeze migration - fix(openai): various fixes (passthrough fields, compact payload, etc.) Conflict resolutions: - domain/constants.go: keep both AccountTypeWindsurfSession + AccountTypeServiceAccount - scheduler_cache_unit_test.go: keep both test functions - gateway_service.go: remove dead code (claudeCodeUserAgentRe, isClaudeCodeRequest) - wire_gen.go: keep Windsurf service chain + add upstream claudeTokenProvider param - frontend/src/types/index.ts: keep windsurf + service_account types - frontend CreateAccountModal.vue: keep Windsurf login + Vertex service_account blocks - frontend PlatformTypeBadge.vue: keep both Session + Vertex cases - account_test_service.go: fix createTestPayload call to pass empty prompt arg
This commit is contained in:
commit
c5eb305f7f
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
docs/claude-relay-service/
|
||||
.codex
|
||||
|
||||
# ===================
|
||||
# Go 后端
|
||||
@ -121,7 +122,7 @@ scripts
|
||||
.code-review-state
|
||||
#openspec/
|
||||
code-reviews/
|
||||
#AGENTS.md
|
||||
AGENTS.md
|
||||
backend/cmd/server/server
|
||||
deploy/docker-compose.override.yml
|
||||
.gocache/
|
||||
|
||||
@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
<td>Thanks to Bestproxy for sponsoring this project! <a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
|
||||
Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
|
||||
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
<td>感谢 Bestproxy 赞助了本项目!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。
|
||||
同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
|
||||
现在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
<td>Bestproxy のご支援に感謝します!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
|
||||
エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
|
||||
<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
BIN
assets/partners/logos/pateway.png
Normal file
BIN
assets/partners/logos/pateway.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.0 KiB |
@ -1 +1 @@
|
||||
0.1.118
|
||||
0.1.121
|
||||
|
||||
@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
||||
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
@ -143,6 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
@ -155,7 +156,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository)
|
||||
windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache)
|
||||
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
@ -182,7 +183,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
@ -191,7 +191,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
|
||||
@ -33,6 +33,7 @@ const (
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeWindsurfSession = "windsurf-session" // Windsurf Session 类型账号(邮箱密码登录获取的 session token + api_key)
|
||||
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@ -99,7 +99,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@ -118,7 +118,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@ -135,19 +135,29 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Filters *BulkUpdateAccountFilters `json:"filters"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
type BulkUpdateAccountFilters struct {
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Group string `json:"group"`
|
||||
Search string `json:"search"`
|
||||
PrivacyMode string `json:"privacy_mode"`
|
||||
}
|
||||
|
||||
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||
@ -1370,6 +1380,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 && req.Filters == nil {
|
||||
response.BadRequest(c, "account_ids or filters is required")
|
||||
return
|
||||
}
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(req.Extra)
|
||||
|
||||
@ -1395,6 +1409,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Filters: toServiceBulkUpdateAccountFilters(req.Filters),
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
@ -1430,6 +1445,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
|
||||
if filters == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.BulkUpdateAccountFilters{
|
||||
Platform: filters.Platform,
|
||||
Type: filters.Type,
|
||||
Status: filters.Status,
|
||||
Group: filters.Group,
|
||||
Search: filters.Search,
|
||||
PrivacyMode: filters.PrivacyMode,
|
||||
}
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
|
||||
@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
|
||||
require.Equal(t, float64(2), data["success"])
|
||||
require.Equal(t, float64(0), data["failed"])
|
||||
}
|
||||
|
||||
func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"filters": map[string]any{
|
||||
"platform": "openai",
|
||||
"type": "oauth",
|
||||
"status": "active",
|
||||
"group": "12",
|
||||
"privacy_mode": "blocked",
|
||||
"search": "bulk-target",
|
||||
},
|
||||
"schedulable": true,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
|
||||
require.True(t, isAddrInTrustedProxies(addr, prefixes))
|
||||
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
|
||||
}
|
||||
|
||||
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
|
||||
// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
|
||||
// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
|
||||
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
|
||||
t.Run("nil input returns nil", func(t *testing.T) {
|
||||
require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
|
||||
})
|
||||
|
||||
t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: "",
|
||||
Action: "filter",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.NotNil(t, out)
|
||||
require.Len(t, out.Rules, 1)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
|
||||
require.Equal(t, "all", out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: " ",
|
||||
Action: "pass",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: "PRIORITY",
|
||||
Action: "filter",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{
|
||||
{ServiceTier: "priority", Action: "filter", Scope: "all"},
|
||||
{ServiceTier: "flex", Action: "block", Scope: "oauth"},
|
||||
{ServiceTier: "all", Action: "pass", Scope: "apikey"},
|
||||
},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Len(t, out.Rules, 3)
|
||||
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
|
||||
require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
|
||||
})
|
||||
}
|
||||
|
||||
@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
|
||||
for i := range s.apiKeys {
|
||||
if s.apiKeys[i].ID == keyID {
|
||||
s.apiKeys[i].Usage5h = 0
|
||||
s.apiKeys[i].Usage1d = 0
|
||||
s.apiKeys[i].Usage7d = 0
|
||||
s.apiKeys[i].Window5hStart = nil
|
||||
s.apiKeys[i].Window1dStart = nil
|
||||
s.apiKeys[i].Window7dStart = nil
|
||||
k := s.apiKeys[i]
|
||||
return &k, nil
|
||||
}
|
||||
}
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
|
||||
}
|
||||
}
|
||||
|
||||
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
|
||||
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
|
||||
type AdminUpdateAPIKeyGroupRequest struct {
|
||||
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
|
||||
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量
|
||||
}
|
||||
|
||||
// UpdateGroup handles updating an API key's group binding
|
||||
// UpdateGroup handles updating an API key's admin-managed fields.
|
||||
// PUT /api/v1/admin/api-keys/:id
|
||||
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var resetKey *service.APIKey
|
||||
if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
|
||||
resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if resetKey != nil && req.GroupID == nil {
|
||||
result.APIKey = resetKey
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
APIKey *dto.APIKey `json:"api_key"`
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
|
||||
require.Nil(t, resp.Data.APIKey.GroupID)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
|
||||
svc := newStubAdminService()
|
||||
now := time.Now()
|
||||
svc.apiKeys[0].Usage5h = 1.2
|
||||
svc.apiKeys[0].Usage1d = 3.4
|
||||
svc.apiKeys[0].Usage7d = 5.6
|
||||
svc.apiKeys[0].Window5hStart = &now
|
||||
svc.apiKeys[0].Window1dStart = &now
|
||||
svc.apiKeys[0].Window7dStart = &now
|
||||
router := setupAPIKeyHandler(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
APIKey struct {
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
} `json:"api_key"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Zero(t, resp.Data.APIKey.Usage5h)
|
||||
require.Zero(t, resp.Data.APIKey.Usage1d)
|
||||
require.Zero(t, resp.Data.APIKey.Usage7d)
|
||||
require.Nil(t, resp.Data.APIKey.Window5hStart)
|
||||
require.Nil(t, resp.Data.APIKey.Window1dStart)
|
||||
require.Nil(t, resp.Data.APIKey.Window7dStart)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
|
||||
svc := &failingUpdateGroupService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
|
||||
@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
AffiliateRebateRate: settings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@ -206,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: settings.EnableCCHSigning,
|
||||
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||
@ -245,9 +249,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
}
|
||||
|
||||
// OpenAI fast policy (stored under a dedicated setting key)
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
|
||||
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
|
||||
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
|
||||
for i, r := range s.Rules {
|
||||
rules[i] = dto.OpenAIFastPolicyRule(r)
|
||||
}
|
||||
return &dto.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
|
||||
//
|
||||
// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为
|
||||
// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时
|
||||
// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由
|
||||
// service.SetOpenAIFastPolicySettings 负责合法值校验。
|
||||
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
|
||||
for i, r := range s.Rules {
|
||||
rules[i] = service.OpenAIFastPolicyRule(r)
|
||||
tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
|
||||
if tier == "" {
|
||||
tier = service.OpenAIFastTierAny
|
||||
}
|
||||
rules[i].ServiceTier = tier
|
||||
}
|
||||
return &service.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
@ -342,6 +388,9 @@ type UpdateSettingsRequest struct {
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
@ -393,9 +442,10 @@ type UpdateSettingsRequest struct {
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
|
||||
// Payment visible method routing
|
||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||
@ -446,6 +496,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@ -485,6 +538,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if affiliateRebateRate > service.AffiliateRebateRateMax {
|
||||
affiliateRebateRate = service.AffiliateRebateRateMax
|
||||
}
|
||||
affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
|
||||
if req.AffiliateRebateFreezeHours != nil {
|
||||
affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
|
||||
}
|
||||
if affiliateRebateFreezeHours < 0 {
|
||||
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
|
||||
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
|
||||
if req.AffiliateRebateDurationDays != nil {
|
||||
affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
|
||||
}
|
||||
if affiliateRebateDurationDays < 0 {
|
||||
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
|
||||
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
|
||||
}
|
||||
affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
|
||||
if req.AffiliateRebatePerInviteeCap != nil {
|
||||
affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
|
||||
}
|
||||
if affiliateRebatePerInviteeCap < 0 {
|
||||
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
|
||||
if req.TableDefaultPageSize <= 0 {
|
||||
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
|
||||
@ -1137,6 +1217,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@ -1192,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.EnableCCHSigning
|
||||
}(),
|
||||
EnableAnthropicCacheTTL1hInjection: func() bool {
|
||||
if req.EnableAnthropicCacheTTL1hInjection != nil {
|
||||
return *req.EnableAnthropicCacheTTL1hInjection
|
||||
}
|
||||
return previousSettings.EnableAnthropicCacheTTL1hInjection
|
||||
}(),
|
||||
PaymentVisibleMethodAlipaySource: func() string {
|
||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||
@ -1314,6 +1403,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update OpenAI fast policy (stored under dedicated key, only when provided).
|
||||
if req.OpenAIFastPolicySettings != nil {
|
||||
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update payment configuration (integrated into system settings).
|
||||
// Skip if no payment fields were provided (prevents accidental wipe).
|
||||
if h.paymentConfigService != nil && hasPaymentFields(req) {
|
||||
@ -1458,6 +1555,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@ -1478,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
||||
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||
@ -1516,6 +1617,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
|
||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||
}
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
|
||||
@ -1768,6 +1874,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AffiliateRebateRate != after.AffiliateRebateRate {
|
||||
changed = append(changed, "affiliate_rebate_rate")
|
||||
}
|
||||
if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
|
||||
changed = append(changed, "affiliate_rebate_freeze_hours")
|
||||
}
|
||||
if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
|
||||
changed = append(changed, "affiliate_rebate_duration_days")
|
||||
}
|
||||
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
|
||||
changed = append(changed, "affiliate_rebate_per_invitee_cap")
|
||||
}
|
||||
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
|
||||
changed = append(changed, "default_subscriptions")
|
||||
}
|
||||
@ -1843,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EnableCCHSigning != after.EnableCCHSigning {
|
||||
changed = append(changed, "enable_cch_signing")
|
||||
}
|
||||
if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
|
||||
changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
|
||||
}
|
||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||
changed = append(changed, "payment_visible_method_alipay_source")
|
||||
}
|
||||
|
||||
@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
if s.values != nil {
|
||||
if value, ok := s.values[key]; ok {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
|
||||
@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
|
||||
VerifyCode string `json:"verify_code,omitempty"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
InvitationCode string `json:"invitation_code,omitempty"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
user,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
strings.TrimSpace(req.AffCode),
|
||||
); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
|
||||
@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
|
||||
|
||||
type completeOIDCOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
|
||||
|
||||
type completeWeChatOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AffCode string `json:"aff_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -106,11 +106,14 @@ type SystemSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@ -139,9 +142,10 @@ type SystemSettings struct {
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||
@ -195,6 +199,9 @@ type SystemSettings struct {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@ -291,6 +298,22 @@ type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO
|
||||
type OpenAIFastPolicyRule struct {
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
ModelWhitelist []string `json:"model_whitelist,omitempty"`
|
||||
FallbackAction string `json:"fallback_action,omitempty"`
|
||||
FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO
|
||||
type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@ -288,6 +288,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// [DEBUG-STICKY] 打印会话 hash 生成结果
|
||||
reqLog.Info("sticky.session_hash_generated",
|
||||
zap.String("session_hash", sessionHash),
|
||||
zap.String("metadata_user_id_raw", parsedReq.MetadataUserID),
|
||||
)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
platform := ""
|
||||
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
@ -304,6 +310,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
// [DEBUG-STICKY] 打印粘性会话查询结果
|
||||
reqLog.Info("sticky.cache_lookup",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("bound_account_id", sessionBoundAccountID),
|
||||
)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
@ -312,6 +323,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
} else {
|
||||
reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash))
|
||||
}
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
@ -591,6 +604,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
reqLog.Info("sticky.selecting_account",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
|
||||
zap.Bool("has_bound_session", hasBoundSession),
|
||||
zap.Int("failed_account_count", len(fs.FailedAccountIDs)),
|
||||
)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
@ -624,6 +643,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// [DEBUG-STICKY] 打印账号选择结果
|
||||
reqLog.Info("sticky.account_selected",
|
||||
zap.Int64("selected_account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.Bool("slot_acquired", selection.Acquired),
|
||||
zap.Bool("has_wait_plan", selection.WaitPlan != nil),
|
||||
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
|
||||
zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID),
|
||||
)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
@ -690,6 +719,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
reqLog.Info("sticky.bind_after_wait",
|
||||
zap.String("session_key", sessionKey),
|
||||
zap.Int64("account_id", account.ID),
|
||||
)
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
@ -924,6 +957,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定粘性会话(成功转发后绑定/刷新)
|
||||
// - 无现有绑定(首次请求):创建绑定
|
||||
// - 选中账号与粘性账号一致:刷新 TTL
|
||||
// - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定,
|
||||
// 下次请求粘性账号恢复后仍可命中
|
||||
if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
|
||||
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := ""
|
||||
if parsed.Multipart {
|
||||
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
|
||||
} else {
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
|
||||
@ -25,6 +25,7 @@ const (
|
||||
easypayStatusPaid = 1
|
||||
easypayHTTPTimeout = 10 * time.Second
|
||||
maxEasypayResponseSize = 1 << 20 // 1MB
|
||||
maxEasypayErrorSummary = 512
|
||||
tradeStatusSuccess = "TRADE_SUCCESS"
|
||||
signTypeMD5 = "MD5"
|
||||
paymentModePopup = "popup"
|
||||
@ -42,17 +43,55 @@ type EasyPay struct {
|
||||
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
|
||||
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
|
||||
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
|
||||
if config[k] == "" {
|
||||
if strings.TrimSpace(config[k]) == "" {
|
||||
return nil, fmt.Errorf("easypay config missing required key: %s", k)
|
||||
}
|
||||
}
|
||||
cfg := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
cfg[k] = v
|
||||
}
|
||||
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
|
||||
return &EasyPay{
|
||||
instanceID: instanceID,
|
||||
config: config,
|
||||
config: cfg,
|
||||
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeEasyPayAPIBase(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.RawPath = ""
|
||||
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
|
||||
}
|
||||
|
||||
func trimEasyPayEndpointPath(path string) string {
|
||||
path = strings.TrimRight(strings.TrimSpace(path), "/")
|
||||
lower := strings.ToLower(path)
|
||||
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
|
||||
if strings.HasSuffix(lower, endpoint) {
|
||||
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (e *EasyPay) apiBase() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeEasyPayAPIBase(e.config["apiBase"])
|
||||
}
|
||||
|
||||
func (e *EasyPay) Name() string { return "EasyPay" }
|
||||
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
|
||||
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
|
||||
@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
|
||||
for k, v := range params {
|
||||
q.Set(k, v)
|
||||
}
|
||||
base := strings.TrimRight(e.config["apiBase"], "/")
|
||||
payURL := base + "/submit.php?" + q.Encode()
|
||||
payURL := e.apiBase() + "/submit.php?" + q.Encode()
|
||||
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
|
||||
}
|
||||
|
||||
@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
|
||||
params["sign"] = easyPaySign(params, e.config["pkey"])
|
||||
params["sign_type"] = signTypeMD5
|
||||
|
||||
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
|
||||
body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay create: %w", err)
|
||||
}
|
||||
@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
|
||||
"act": "order", "pid": e.config["pid"],
|
||||
"key": e.config["pkey"], "out_trade_no": tradeNo,
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
|
||||
body, err := e.post(ctx, e.apiBase()+"/api.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay query: %w", err)
|
||||
}
|
||||
@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
|
||||
}
|
||||
|
||||
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
params := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"],
|
||||
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
|
||||
attempts := e.refundAttempts(req)
|
||||
if len(attempts) == 0 {
|
||||
return nil, fmt.Errorf("easypay refund missing order identifier")
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund: %w", err)
|
||||
var firstErr error
|
||||
for i, attempt := range attempts {
|
||||
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund request: %w", err)
|
||||
}
|
||||
if err := parseEasyPayRefundResponse(status, body); err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
type easyPayRefundAttempt struct {
|
||||
params map[string]string
|
||||
refundID string
|
||||
}
|
||||
|
||||
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
|
||||
base := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
|
||||
}
|
||||
var attempts []easyPayRefundAttempt
|
||||
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
|
||||
params := cloneStringMap(base)
|
||||
params["out_trade_no"] = orderID
|
||||
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
|
||||
}
|
||||
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
|
||||
params := cloneStringMap(base)
|
||||
params["trade_no"] = tradeNo
|
||||
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
func cloneStringMap(in map[string]string) map[string]string {
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isEasyPayRefundOrderNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
lower := strings.ToLower(msg)
|
||||
return strings.Contains(msg, "订单编号不存在") ||
|
||||
strings.Contains(msg, "订单不存在") ||
|
||||
strings.Contains(lower, "order not found") ||
|
||||
strings.Contains(lower, "not exist")
|
||||
}
|
||||
|
||||
func parseEasyPayRefundResponse(status int, body []byte) error {
|
||||
summary := summarizeEasyPayResponse(body)
|
||||
if status < http.StatusOK || status >= http.StatusMultipleChoices {
|
||||
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
|
||||
lower := strings.ToLower(trimmed)
|
||||
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
|
||||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
|
||||
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Code any `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("easypay parse refund: %w", err)
|
||||
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
|
||||
}
|
||||
if resp.Code != easypayCodeSuccess {
|
||||
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
|
||||
if !easyPayResponseCodeIsSuccess(resp.Code) {
|
||||
msg := strings.TrimSpace(resp.Msg)
|
||||
if msg == "" {
|
||||
msg = summary
|
||||
}
|
||||
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func easyPayResponseCodeIsSuccess(code any) bool {
|
||||
switch v := code.(type) {
|
||||
case float64:
|
||||
return int(v) == easypayCodeSuccess
|
||||
case string:
|
||||
n, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
return err == nil && n == easypayCodeSuccess
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeEasyPayResponse(body []byte) string {
|
||||
summary := strings.Join(strings.Fields(string(body)), " ")
|
||||
if summary == "" {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(summary) > maxEasypayErrorSummary {
|
||||
return summary[:maxEasypayErrorSummary] + "..."
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
}
|
||||
|
||||
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
|
||||
body, _, err := e.postRaw(ctx, endpoint, params)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := e.httpClient.Do(req)
|
||||
client := e.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: easypayHTTPTimeout}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func easyPaySign(params map[string]string, pkey string) string {
|
||||
|
||||
196
backend/internal/payment/provider/easypay_refund_test.go
Normal file
196
backend/internal/payment/provider/easypay_refund_test.go
Normal file
@ -0,0 +1,196 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestNormalizeEasyPayAPIBase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
|
||||
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
|
||||
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotPath string
|
||||
var gotQuery url.Values
|
||||
var gotForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotQuery = r.URL.Query()
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm: %v", err)
|
||||
}
|
||||
gotForm = r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL+"/mapi.php")
|
||||
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
TradeNo: "trade-123",
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Refund returned error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
|
||||
t.Fatalf("Refund response = %+v, want success", resp)
|
||||
}
|
||||
if gotPath != "/api.php" {
|
||||
t.Fatalf("refund path = %q, want /api.php", gotPath)
|
||||
}
|
||||
if gotQuery.Get("act") != "refund" {
|
||||
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
|
||||
}
|
||||
for key, want := range map[string]string{
|
||||
"pid": "pid-1",
|
||||
"key": "pkey-1",
|
||||
"out_trade_no": "out-456",
|
||||
"money": "1.50",
|
||||
} {
|
||||
if got := gotForm.Get(key); got != want {
|
||||
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
|
||||
}
|
||||
}
|
||||
if got := gotForm.Get("trade_no"); got != "" {
|
||||
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotForms []url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api.php" {
|
||||
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("act") != "refund" {
|
||||
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm: %v", err)
|
||||
}
|
||||
gotForms = append(gotForms, r.PostForm)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if len(gotForms) == 1 {
|
||||
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL+"/mapi.php")
|
||||
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
TradeNo: "trade-123",
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Refund returned error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
|
||||
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
|
||||
}
|
||||
if len(gotForms) != 2 {
|
||||
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
|
||||
}
|
||||
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
|
||||
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
|
||||
}
|
||||
if got := gotForms[0].Get("trade_no"); got != "" {
|
||||
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
|
||||
}
|
||||
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
|
||||
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
|
||||
}
|
||||
if got := gotForms[1].Get("out_trade_no"); got != "" {
|
||||
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayRefundResponseErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
want string
|
||||
}{
|
||||
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
|
||||
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
|
||||
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
|
||||
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, _ = w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestEasyPay(t, server.URL)
|
||||
_, err := provider.Refund(context.Background(), payment.RefundRequest{
|
||||
OrderID: "out-456",
|
||||
Amount: "1.50",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Refund returned nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
|
||||
t.Helper()
|
||||
|
||||
provider, err := NewEasyPay("test-instance", map[string]string{
|
||||
"pid": "pid-1",
|
||||
"pkey": "pkey-1",
|
||||
"apiBase": apiBase,
|
||||
"notifyUrl": "https://example.com/notify",
|
||||
"returnUrl": "https://example.com/return",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewEasyPay: %v", err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cached",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Cached response"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 54006,
|
||||
OutputTokens: 123,
|
||||
TotalTokens: 54129,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 50688,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
|
||||
assert.Equal(t, 3318, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 123, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cached_clamp",
|
||||
Model: "gpt-5.2",
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 5,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 150,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
|
||||
assert.Equal(t, 0, anth.Usage.InputTokens)
|
||||
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 5, anth.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "tool_use", anth.Content[1].Type)
|
||||
assert.Equal(t, "call_1", anth.Content[1].ID)
|
||||
assert.Equal(t, "get_weather", anth.Content[1].Name)
|
||||
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_read",
|
||||
Model: "gpt-5.5",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_read",
|
||||
Name: "Read",
|
||||
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.Equal(t, "tool_use", anth.Content[0].Type)
|
||||
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_other",
|
||||
Model: "gpt-5.5",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_other",
|
||||
Name: "Search",
|
||||
Arguments: `{"query":""}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
|
||||
require.Len(t, anth.Content, 1)
|
||||
assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
|
||||
@ -343,6 +434,36 @@ func TestStreamingTextOnly(t *testing.T) {
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 54006,
|
||||
OutputTokens: 123,
|
||||
TotalTokens: 54129,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 50688,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
assert.Equal(t, 3318, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 123, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingToolCall(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
@ -393,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
|
||||
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
|
||||
}, state)
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
|
||||
}, state)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 0,
|
||||
Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
|
||||
}, state)
|
||||
assert.Len(t, events, 0)
|
||||
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.done",
|
||||
OutputIndex: 0,
|
||||
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
|
||||
}, state)
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "content_block_delta", events[0].Type)
|
||||
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
|
||||
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
|
||||
assert.Equal(t, "content_block_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestStreamingReasoning(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
|
||||
@ -835,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
fn, ok := tc["function"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
assert.NotContains(t, tc, "function")
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
|
||||
req := &ResponsesRequest{
|
||||
Model: "gpt-5.2",
|
||||
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
|
||||
ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := ResponsesToAnthropicRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc map[string]string
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "tool", tc["type"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
}
|
||||
|
||||
func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
|
||||
req := &ResponsesRequest{
|
||||
Model: "gpt-5.2",
|
||||
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
|
||||
ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
|
||||
}
|
||||
|
||||
resp, err := ResponsesToAnthropicRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var tc map[string]string
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "tool", tc["type"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
// {"type":"auto"} → "auto"
|
||||
// {"type":"any"} → "required"
|
||||
// {"type":"none"} → "none"
|
||||
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
|
||||
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
|
||||
return json.Marshal("none")
|
||||
case "tool":
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": tc.Name},
|
||||
"type": "function",
|
||||
"name": tc.Name,
|
||||
})
|
||||
default:
|
||||
// Pass through unknown types as-is
|
||||
|
||||
@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
assert.Equal(t, "get_weather", tc["name"])
|
||||
assert.NotContains(t, tc, "function")
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
|
||||
@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
|
||||
//
|
||||
// "auto" → "auto"
|
||||
// "none" → "none"
|
||||
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
// {"name":"X"} → {"type":"function","name":"X"}
|
||||
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||
var s string
|
||||
@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": obj.Name},
|
||||
"type": "function",
|
||||
"name": obj.Name,
|
||||
})
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallID(item.CallID),
|
||||
Name: item.Name,
|
||||
Input: json.RawMessage(item.Arguments),
|
||||
Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
|
||||
})
|
||||
case "web_search_call":
|
||||
toolUseID := "srvtoolu_" + item.ID
|
||||
@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
|
||||
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
|
||||
|
||||
if resp.Usage != nil {
|
||||
out.Usage = AnthropicUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
|
||||
if usage == nil {
|
||||
return AnthropicUsage{}
|
||||
}
|
||||
|
||||
cachedTokens := 0
|
||||
if usage.InputTokensDetails != nil {
|
||||
cachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
|
||||
inputTokens := usage.InputTokens - cachedTokens
|
||||
if inputTokens < 0 {
|
||||
inputTokens = 0
|
||||
}
|
||||
|
||||
return AnthropicUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
|
||||
if name != "Read" || raw == "" {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
|
||||
var input map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(raw), &input); err != nil {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
|
||||
if pages, ok := input["pages"]; !ok || string(pages) != `""` {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
|
||||
delete(input, "pages")
|
||||
sanitized, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
|
||||
ContentBlockIndex int
|
||||
ContentBlockOpen bool
|
||||
CurrentBlockType string // "text" | "thinking" | "tool_use"
|
||||
CurrentToolName string
|
||||
CurrentToolArgs string
|
||||
|
||||
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
|
||||
OutputIndexToBlockIdx map[int]int
|
||||
@ -165,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToAnthHandleFuncArgsDelta(evt, state)
|
||||
case "response.function_call_arguments.done":
|
||||
return resToAnthHandleBlockDone(state)
|
||||
return resToAnthHandleFuncArgsDone(evt, state)
|
||||
case "response.output_item.done":
|
||||
return resToAnthHandleOutputItemDone(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
@ -262,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
|
||||
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
|
||||
state.ContentBlockOpen = true
|
||||
state.CurrentBlockType = "tool_use"
|
||||
state.CurrentToolName = evt.Item.Name
|
||||
state.CurrentToolArgs = ""
|
||||
|
||||
events = append(events, AnthropicStreamEvent{
|
||||
Type: "content_block_start",
|
||||
@ -342,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
|
||||
return nil
|
||||
}
|
||||
|
||||
if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
|
||||
state.CurrentToolArgs += evt.Delta
|
||||
return nil
|
||||
}
|
||||
|
||||
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
@ -357,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
|
||||
}}
|
||||
}
|
||||
|
||||
func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
|
||||
return resToAnthHandleBlockDone(state)
|
||||
}
|
||||
|
||||
raw := evt.Arguments
|
||||
if raw == "" {
|
||||
raw = state.CurrentToolArgs
|
||||
}
|
||||
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
|
||||
if len(sanitized) == 0 {
|
||||
return closeCurrentBlock(state)
|
||||
}
|
||||
|
||||
idx := state.ContentBlockIndex
|
||||
events := []AnthropicStreamEvent{{
|
||||
Type: "content_block_delta",
|
||||
Index: &idx,
|
||||
Delta: &AnthropicDelta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(sanitized),
|
||||
},
|
||||
}}
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
return events
|
||||
}
|
||||
|
||||
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
@ -466,11 +540,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
stopReason := "end_turn"
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
state.InputTokens = evt.Response.Usage.InputTokens
|
||||
state.OutputTokens = evt.Response.Usage.OutputTokens
|
||||
if evt.Response.Usage.InputTokensDetails != nil {
|
||||
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
|
||||
state.InputTokens = usage.InputTokens
|
||||
state.OutputTokens = usage.OutputTokens
|
||||
state.CacheReadInputTokens = usage.CacheReadInputTokens
|
||||
}
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
@ -509,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
|
||||
idx := state.ContentBlockIndex
|
||||
state.ContentBlockOpen = false
|
||||
state.ContentBlockIndex++
|
||||
state.CurrentToolName = ""
|
||||
state.CurrentToolArgs = ""
|
||||
return []AnthropicStreamEvent{{
|
||||
Type: "content_block_stop",
|
||||
Index: &idx,
|
||||
|
||||
@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
|
||||
// "auto" → {"type":"auto"}
|
||||
// "required" → {"type":"any"}
|
||||
// "none" → {"type":"none"}
|
||||
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
|
||||
// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
|
||||
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
|
||||
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try as string first
|
||||
var s string
|
||||
@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
|
||||
// Try as object with type=function
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
|
||||
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
|
||||
name := strings.TrimSpace(tc.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(tc.Function.Name)
|
||||
}
|
||||
if name == "" {
|
||||
return raw, nil
|
||||
}
|
||||
return json.Marshal(map[string]string{
|
||||
"type": "tool",
|
||||
"name": tc.Function.Name,
|
||||
"name": name,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -2,16 +2,28 @@ package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
const (
|
||||
requestBodyReadInitCap = 512
|
||||
requestBodyReadMaxInitCap = 1 << 20
|
||||
// maxDecompressedBodySize limits the decompressed request body to 64 MB
|
||||
// to prevent decompression bomb attacks.
|
||||
maxDecompressedBodySize = 64 << 20
|
||||
)
|
||||
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
|
||||
// on content length, transparently decoding any Content-Encoding the upstream
|
||||
// client used to compress the body (zstd, gzip, deflate).
|
||||
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if _, err := io.Copy(buf, req.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
raw := buf.Bytes()
|
||||
|
||||
enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
|
||||
if enc == "" || enc == "identity" {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
decoded, err := decompressRequestBody(enc, raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
|
||||
}
|
||||
|
||||
req.Header.Del("Content-Encoding")
|
||||
req.Header.Del("Content-Length")
|
||||
req.ContentLength = int64(len(decoded))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
|
||||
switch encoding {
|
||||
case "zstd":
|
||||
dec, err := zstd.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dec.Close()
|
||||
return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize))
|
||||
case "gzip", "x-gzip":
|
||||
gr, err := gzip.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = gr.Close() }()
|
||||
return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize))
|
||||
case "deflate":
|
||||
zr, err := zlib.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = zr.Close() }()
|
||||
return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize))
|
||||
default:
|
||||
return nil, errors.New("unsupported Content-Encoding")
|
||||
}
|
||||
}
|
||||
|
||||
143
backend/internal/pkg/httputil/body_test.go
Normal file
143
backend/internal/pkg/httputil/body_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
|
||||
|
||||
func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
|
||||
t.Helper()
|
||||
req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
if encoding != "" {
|
||||
req.Header.Set("Content-Encoding", encoding)
|
||||
}
|
||||
req.ContentLength = int64(len(body))
|
||||
return req
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte(samplePayload), "")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
|
||||
enc, _ := zstd.NewWriter(nil)
|
||||
compressed := enc.EncodeAll([]byte(samplePayload), nil)
|
||||
_ = enc.Close()
|
||||
|
||||
req := newRequestWithBody(t, compressed, "zstd")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
if req.Header.Get("Content-Encoding") != "" {
|
||||
t.Fatalf("Content-Encoding should be cleared after decoding")
|
||||
}
|
||||
if req.ContentLength != int64(len(samplePayload)) {
|
||||
t.Fatalf("ContentLength not updated: %d", req.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
if _, err := gw.Write([]byte(samplePayload)); err != nil {
|
||||
t.Fatalf("gzip write: %v", err)
|
||||
}
|
||||
if err := gw.Close(); err != nil {
|
||||
t.Fatalf("gzip close: %v", err)
|
||||
}
|
||||
|
||||
req := newRequestWithBody(t, buf.Bytes(), "gzip")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
zw := zlib.NewWriter(&buf)
|
||||
if _, err := zw.Write([]byte(samplePayload)); err != nil {
|
||||
t.Fatalf("zlib write: %v", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("zlib close: %v", err)
|
||||
}
|
||||
|
||||
req := newRequestWithBody(t, buf.Bytes(), "deflate")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte(samplePayload), "br")
|
||||
_, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported encoding, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "br") {
|
||||
t.Fatalf("error should mention encoding, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
|
||||
_, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for corrupt zstd body, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil body, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
|
||||
req := newRequestWithBody(t, []byte(samplePayload), "identity")
|
||||
got, err := ReadRequestBodyWithPrealloc(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != samplePayload {
|
||||
t.Fatalf("body mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
|
||||
return bound, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
|
||||
if amount <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var applied bool
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
res, err := txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
|
||||
amount, inviterID,
|
||||
)
|
||||
// freezeHours > 0: add to frozen quota; == 0: add to available quota directly
|
||||
var updateSQL string
|
||||
if freezeHours > 0 {
|
||||
updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
|
||||
} else {
|
||||
updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
|
||||
}
|
||||
res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -106,10 +110,19 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
if freezeHours > 0 {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
|
||||
inviterID, amount, inviteeUserID, freezeHours); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
} else {
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
applied = true
|
||||
@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID);
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx,
|
||||
`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
|
||||
inviterID, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
var total float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return total, rows.Close()
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
|
||||
var thawed float64
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
var err error
|
||||
thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
|
||||
return err
|
||||
})
|
||||
return thawed, err
|
||||
}
|
||||
|
||||
// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
|
||||
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH matured AS (
|
||||
UPDATE user_affiliate_ledger
|
||||
SET frozen_until = NULL, updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND frozen_until IS NOT NULL
|
||||
AND frozen_until <= NOW()
|
||||
RETURNING amount
|
||||
)
|
||||
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("thaw frozen quota: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var thawed float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&thawed); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if thawed <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
_, err = txClient.ExecContext(txCtx, `
|
||||
UPDATE user_affiliates
|
||||
SET aff_quota = aff_quota + $1,
|
||||
aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $2`, thawed, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("move thawed quota: %w", err)
|
||||
}
|
||||
return thawed, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
|
||||
var transferred float64
|
||||
var newBalance float64
|
||||
@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
// Thaw any matured frozen quota before transfer.
|
||||
if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
|
||||
return fmt.Errorf("thaw before transfer: %w", err)
|
||||
}
|
||||
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH claimed AS (
|
||||
SELECT aff_quota::double precision AS amount
|
||||
@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64,
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.created_at
|
||||
ua.created_at,
|
||||
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
|
||||
FROM user_affiliates ua
|
||||
LEFT JOIN users u ON u.id = ua.user_id
|
||||
LEFT JOIN user_affiliate_ledger ual
|
||||
ON ual.user_id = $1
|
||||
AND ual.source_user_id = ua.user_id
|
||||
AND ual.action = 'accrue'
|
||||
WHERE ua.inviter_id = $1
|
||||
GROUP BY ua.user_id, u.email, u.username, ua.created_at
|
||||
ORDER BY ua.created_at DESC
|
||||
LIMIT $2`, inviterID, limit)
|
||||
if err != nil {
|
||||
@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateInvitee
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
|
||||
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.CreatedAt = &createdAt
|
||||
@ -299,6 +393,7 @@ SELECT user_id,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_frozen_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
@ -326,6 +421,7 @@ WHERE user_id = $1`, userID)
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffFrozenQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
@ -351,6 +447,7 @@ SELECT user_id,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_frozen_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffFrozenQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
|
||||
@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, bound, "invitee must bind to inviter")
|
||||
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5)
|
||||
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
|
||||
require.NoError(t, err)
|
||||
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||
|
||||
|
||||
@ -24,6 +24,49 @@ const (
|
||||
|
||||
defaultSchedulerSnapshotMGetChunkSize = 128
|
||||
defaultSchedulerSnapshotWriteChunkSize = 256
|
||||
|
||||
// snapshotGraceTTLSeconds 旧快照过期的宽限期(秒)。
|
||||
// 替代立即 DEL,让正在读取旧版本的 reader 有足够时间完成 ZRANGE。
|
||||
snapshotGraceTTLSeconds = 60
|
||||
)
|
||||
|
||||
var (
|
||||
// activateSnapshotScript 原子 CAS 切换快照版本。
|
||||
// 仅当新版本号 >= 当前激活版本时才切换,防止并发写入导致版本回滚。
|
||||
// 旧快照使用 EXPIRE 设置宽限期而非立即 DEL,避免与 reader 竞态。
|
||||
//
|
||||
// KEYS[1] = activeKey (sched:active:{bucket})
|
||||
// KEYS[2] = readyKey (sched:ready:{bucket})
|
||||
// KEYS[3] = bucketSetKey (sched:buckets)
|
||||
// KEYS[4] = snapshotKey (新写入的快照 key)
|
||||
// ARGV[1] = 新版本号字符串
|
||||
// ARGV[2] = bucket 字符串 (用于 SADD)
|
||||
// ARGV[3] = 快照 key 前缀 (用于构造旧快照 key)
|
||||
// ARGV[4] = 宽限期 TTL 秒数
|
||||
//
|
||||
// 返回 1 = 已激活, 0 = 版本过旧未激活
|
||||
activateSnapshotScript = redis.NewScript(`
|
||||
local currentActive = redis.call('GET', KEYS[1])
|
||||
local newVersion = tonumber(ARGV[1])
|
||||
|
||||
if currentActive ~= false then
|
||||
local curVersion = tonumber(currentActive)
|
||||
if curVersion and newVersion < curVersion then
|
||||
redis.call('DEL', KEYS[4])
|
||||
return 0
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('SET', KEYS[1], ARGV[1])
|
||||
redis.call('SET', KEYS[2], '1')
|
||||
redis.call('SADD', KEYS[3], ARGV[2])
|
||||
|
||||
if currentActive ~= false and currentActive ~= ARGV[1] then
|
||||
redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type schedulerCache struct {
|
||||
@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
|
||||
|
||||
// Phase 1: 分配新版本号并写入快照数据。
|
||||
// INCR 保证每个调用方获得唯一递增版本号。
|
||||
// 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。
|
||||
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
|
||||
version, err := c.rdb.Incr(ctx, versionKey).Result()
|
||||
if err != nil {
|
||||
@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
if len(accounts) > 0 {
|
||||
// 使用序号作为 score,保持数据库返回的排序语义。
|
||||
members := make([]redis.Z, 0, len(accounts))
|
||||
@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
Member: strconv.FormatInt(account.ID, 10),
|
||||
})
|
||||
}
|
||||
pipe := c.rdb.Pipeline()
|
||||
for start := 0; start < len(members); start += c.writeChunkSize {
|
||||
end := start + c.writeChunkSize
|
||||
if end > len(members) {
|
||||
@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
}
|
||||
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
|
||||
}
|
||||
} else {
|
||||
pipe.Del(ctx, snapshotKey)
|
||||
}
|
||||
pipe.Set(ctx, activeKey, versionStr, 0)
|
||||
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
|
||||
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return err
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if oldActive != "" && oldActive != versionStr {
|
||||
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
|
||||
// Phase 2: 原子 CAS 激活版本。
|
||||
// Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针,
|
||||
// 防止并发写入导致版本回滚。
|
||||
// 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。
|
||||
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
|
||||
snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
|
||||
|
||||
keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
|
||||
args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
|
||||
|
||||
_, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
|
||||
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
|
||||
key := schedulerBucketKey(schedulerLockPrefix, bucket)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
|
||||
if err != nil {
|
||||
@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
|
||||
SessionWindowStart: account.SessionWindowStart,
|
||||
SessionWindowEnd: account.SessionWindowEnd,
|
||||
SessionWindowStatus: account.SessionWindowStatus,
|
||||
AccountGroups: filterSchedulerAccountGroups(account.AccountGroups),
|
||||
GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups),
|
||||
Credentials: filterSchedulerCredentials(account.Credentials),
|
||||
Extra: filterSchedulerExtra(account.Extra),
|
||||
}
|
||||
}
|
||||
|
||||
func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup {
|
||||
if len(accountGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]service.AccountGroup, 0, len(accountGroups))
|
||||
for _, ag := range accountGroups {
|
||||
if ag.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, service.AccountGroup{
|
||||
AccountID: ag.AccountID,
|
||||
GroupID: ag.GroupID,
|
||||
Priority: ag.Priority,
|
||||
CreatedAt: ag.CreatedAt,
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 {
|
||||
if len(groupIDs) == 0 && len(accountGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups))
|
||||
filtered := make([]int64, 0, len(groupIDs)+len(accountGroups))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
for _, ag := range accountGroups {
|
||||
if ag.GroupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[ag.GroupID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[ag.GroupID] = struct{}{}
|
||||
filtered = append(filtered, ag.GroupID)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
|
||||
if len(credentials) == 0 {
|
||||
return nil
|
||||
|
||||
@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
|
||||
SessionWindowStart: &now,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
SessionWindowStatus: "active",
|
||||
GroupIDs: []int64{bucket.GroupID},
|
||||
AccountGroups: []service.AccountGroup{
|
||||
{
|
||||
AccountID: 101,
|
||||
GroupID: bucket.GroupID,
|
||||
Priority: 5,
|
||||
Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
|
||||
@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
|
||||
require.Equal(t, 4, got.GetMaxSessions())
|
||||
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs)
|
||||
require.Len(t, got.AccountGroups, 1)
|
||||
require.Equal(t, account.ID, got.AccountGroups[0].AccountID)
|
||||
require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID)
|
||||
require.Nil(t, got.AccountGroups[0].Group)
|
||||
|
||||
full, err := cache.GetAccount(ctx, account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, full)
|
||||
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
|
||||
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
|
||||
require.Len(t, full.AccountGroups, 1)
|
||||
require.NotNil(t, full.AccountGroups[0].Group)
|
||||
}
|
||||
|
||||
@ -56,3 +56,43 @@ func TestBuildSchedulerMetadataAccount_KeepsModelRateLimits(t *testing.T) {
|
||||
require.Equal(t, modelLimits, got.Extra["model_rate_limits"], "model_rate_limits must be carried into scheduler snapshot for rate-limit-aware selection")
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
}
|
||||
|
||||
func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
|
||||
account := service.Account{
|
||||
ID: 42,
|
||||
Platform: service.PlatformAnthropic,
|
||||
GroupIDs: []int64{7, 9, 7, 0},
|
||||
AccountGroups: []service.AccountGroup{
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 7,
|
||||
Priority: 2,
|
||||
Account: &service.Account{ID: 42, Name: "drop-from-metadata"},
|
||||
Group: &service.Group{ID: 7, Name: "drop-from-metadata"},
|
||||
},
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 11,
|
||||
Priority: 3,
|
||||
Group: &service.Group{ID: 11, Name: "drop-from-metadata"},
|
||||
},
|
||||
{
|
||||
AccountID: 42,
|
||||
GroupID: 0,
|
||||
Priority: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := buildSchedulerMetadataAccount(account)
|
||||
|
||||
require.Equal(t, []int64{7, 9, 11}, got.GroupIDs)
|
||||
require.Len(t, got.AccountGroups, 2)
|
||||
require.Equal(t, int64(42), got.AccountGroups[0].AccountID)
|
||||
require.Equal(t, int64(7), got.AccountGroups[0].GroupID)
|
||||
require.Equal(t, 2, got.AccountGroups[0].Priority)
|
||||
require.Nil(t, got.AccountGroups[0].Account)
|
||||
require.Nil(t, got.AccountGroups[0].Group)
|
||||
require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
|
||||
require.Nil(t, got.Groups)
|
||||
}
|
||||
|
||||
@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@ -737,6 +740,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"enable_cch_signing": false,
|
||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"web_search_emulation_enabled": false,
|
||||
@ -745,6 +749,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_alipay_enabled": true,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": true,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
},
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": [],
|
||||
"payment_enabled": false,
|
||||
@ -898,6 +912,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"default_concurrency": 0,
|
||||
"default_balance": 0,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
"affiliate_rebate_per_invitee_cap": 0,
|
||||
"default_user_rpm_limit": 0,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
@ -918,12 +935,23 @@ func TestAPIContracts(t *testing.T) {
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"enable_cch_signing": false,
|
||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||
"web_search_emulation_enabled": false,
|
||||
"payment_visible_method_alipay_source": "",
|
||||
"payment_visible_method_wxpay_source": "",
|
||||
"payment_visible_method_alipay_enabled": false,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": false,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
},
|
||||
"payment_enabled": false,
|
||||
"payment_min_amount": 0,
|
||||
"payment_max_amount": 0,
|
||||
|
||||
@ -66,6 +66,7 @@ func isOpenAIImageModel(model string) bool {
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
windsurfChatService *WindsurfChatService
|
||||
httpUpstream HTTPUpstream
|
||||
@ -77,6 +78,7 @@ type AccountTestService struct {
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
windsurfChatService *WindsurfChatService,
|
||||
httpUpstream HTTPUpstream,
|
||||
@ -86,6 +88,7 @@ func NewAccountTestService(
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
windsurfChatService: windsurfChatService,
|
||||
httpUpstream: httpUpstream,
|
||||
@ -223,6 +226,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
if account.IsBedrock() {
|
||||
return s.testBedrockAccountConnection(c, ctx, account, testModelID, prompt)
|
||||
}
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
@ -326,6 +332,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||
if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
|
||||
testModelID = mappedModel
|
||||
} else {
|
||||
testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
payload, err := createTestPayload(testModelID, "")
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
|
||||
}
|
||||
|
||||
if s.claudeTokenProvider == nil {
|
||||
return s.sendErrorAndEnd(c, "Claude token provider not configured")
|
||||
}
|
||||
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
|
||||
}
|
||||
|
||||
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errMsg)
|
||||
}
|
||||
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string, prompt string) error {
|
||||
region := bedrockRuntimeRegion(account)
|
||||
@ -728,8 +802,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
testModelID = geminicli.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
// For static upstream credentials with model mapping, map the model
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@ -757,6 +831,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeServiceAccount:
|
||||
req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@ -948,6 +1024,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
|
||||
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, fmt.Errorf("gemini token provider not configured")
|
||||
}
|
||||
accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service account access token: %w", err)
|
||||
}
|
||||
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
|
||||
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
|
||||
var inner map[string]any
|
||||
@ -1286,7 +1383,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
|
||||
apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
|
||||
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
|
||||
require.Contains(t, rec.Body.String(), "\"success\":true")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 54,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
|
||||
require.Contains(t, rec.Body.String(), "\"success\":true")
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -58,6 +59,7 @@ type AdminService interface {
|
||||
|
||||
// API Key management (admin)
|
||||
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
|
||||
AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)
|
||||
|
||||
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
|
||||
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
|
||||
@ -291,6 +293,7 @@ type UpdateAccountInput struct {
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Filters *BulkUpdateAccountFilters
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
@ -307,6 +310,15 @@ type BulkUpdateAccountsInput struct {
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
type BulkUpdateAccountFilters struct {
|
||||
Platform string
|
||||
Type string
|
||||
Status string
|
||||
Group string
|
||||
Search string
|
||||
PrivacyMode string
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
type BulkUpdateAccountResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
@ -1961,6 +1973,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
|
||||
func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apiKey.Usage5h = 0
|
||||
apiKey.Usage1d = 0
|
||||
apiKey.Usage7d = 0
|
||||
apiKey.Window5hStart = nil
|
||||
apiKey.Window1dStart = nil
|
||||
apiKey.Window7dStart = nil
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
}
|
||||
if s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// ReplaceUserGroup 替换用户的专属分组
|
||||
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
|
||||
if oldGroupID == newGroupID {
|
||||
@ -2286,6 +2322,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||
if len(input.AccountIDs) == 0 && input.Filters != nil {
|
||||
accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
input.AccountIDs = accountIDs
|
||||
}
|
||||
|
||||
result := &BulkUpdateAccountsResult{
|
||||
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
@ -2401,6 +2445,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
|
||||
if filters == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
groupID := int64(0)
|
||||
switch strings.TrimSpace(filters.Group) {
|
||||
case "":
|
||||
case "ungrouped":
|
||||
groupID = AccountListGroupUngrouped
|
||||
default:
|
||||
parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid group filter: %w", err)
|
||||
}
|
||||
groupID = parsedGroupID
|
||||
}
|
||||
|
||||
const pageSize = 500
|
||||
page := 1
|
||||
accountIDs := make([]int64, 0, pageSize)
|
||||
|
||||
for {
|
||||
accounts, total, err := s.ListAccounts(
|
||||
ctx,
|
||||
page,
|
||||
pageSize,
|
||||
filters.Platform,
|
||||
filters.Type,
|
||||
filters.Status,
|
||||
filters.Search,
|
||||
groupID,
|
||||
filters.PrivacyMode,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountIDs = append(accountIDs, account.ID)
|
||||
}
|
||||
if int64(len(accountIDs)) >= total || len(accounts) == 0 {
|
||||
return accountIDs, nil
|
||||
}
|
||||
page++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
|
||||
@ -5,8 +5,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
|
||||
getByIDCalled []int64
|
||||
listByGroupData map[int64][]Account
|
||||
listByGroupErr map[int64]error
|
||||
listData []Account
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
listCalled bool
|
||||
lastListParams pagination.PaginationParams
|
||||
lastListFilters struct {
|
||||
platform string
|
||||
accountType string
|
||||
status string
|
||||
search string
|
||||
groupID int64
|
||||
privacyMode string
|
||||
}
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listCalled = true
|
||||
s.lastListParams = params
|
||||
s.lastListFilters.platform = platform
|
||||
s.lastListFilters.accountType = accountType
|
||||
s.lastListFilters.status = status
|
||||
s.lastListFilters.search = search
|
||||
s.lastListFilters.groupID = groupID
|
||||
s.lastListFilters.privacyMode = privacyMode
|
||||
if s.listErr != nil {
|
||||
return nil, nil, s.listErr
|
||||
}
|
||||
if s.listResult != nil {
|
||||
return s.listData, s.listResult, nil
|
||||
}
|
||||
return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
|
||||
// No BindGroups should have been called since the check runs before any write.
|
||||
require.Empty(t, repo.bindGroupsCalls)
|
||||
}
|
||||
|
||||
func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
listData: []Account{
|
||||
{ID: 7},
|
||||
{ID: 11},
|
||||
},
|
||||
listResult: &pagination.PaginationResult{Total: 2},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
|
||||
require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
|
||||
require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
|
||||
|
||||
filtersValue := reflect.New(filtersField.Type().Elem())
|
||||
filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
|
||||
filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
|
||||
filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
|
||||
filtersValue.Elem().FieldByName("Group").SetString("12")
|
||||
filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
|
||||
filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
|
||||
filtersField.Set(filtersValue)
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
|
||||
require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
|
||||
require.Equal(t, StatusActive, repo.lastListFilters.status)
|
||||
require.Equal(t, "bulk-target", repo.lastListFilters.search)
|
||||
require.Equal(t, int64(12), repo.lastListFilters.groupID)
|
||||
require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
|
||||
require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
|
||||
require.Equal(t, 2, result.Success)
|
||||
require.Equal(t, 0, result.Failed)
|
||||
require.Equal(t, []int64{7, 11}, result.SuccessIDs)
|
||||
}
|
||||
|
||||
@ -65,16 +65,18 @@ type AffiliateSummary struct {
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AffiliateInvitee struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt *time.Time `json:"created_at,omitempty"`
|
||||
TotalRebate float64 `json:"total_rebate"`
|
||||
}
|
||||
|
||||
type AffiliateDetail struct {
|
||||
@ -83,6 +85,7 @@ type AffiliateDetail struct {
|
||||
InviterID *int64 `json:"inviter_id,omitempty"`
|
||||
AffCount int `json:"aff_count"`
|
||||
AffQuota float64 `json:"aff_quota"`
|
||||
AffFrozenQuota float64 `json:"aff_frozen_quota"`
|
||||
AffHistoryQuota float64 `json:"aff_history_quota"`
|
||||
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
|
||||
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
|
||||
@ -95,7 +98,9 @@ type AffiliateRepository interface {
|
||||
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
|
||||
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
|
||||
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
|
||||
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
|
||||
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
|
||||
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
|
||||
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
|
||||
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
|
||||
|
||||
@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64
|
||||
}
|
||||
|
||||
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
|
||||
// Lazy thaw: move any matured frozen quota to available before reading.
|
||||
if s != nil && s.repo != nil {
|
||||
// best-effort: thaw failure is non-fatal
|
||||
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
|
||||
}
|
||||
|
||||
summary, err := s.EnsureUserAffiliate(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
|
||||
InviterID: summary.InviterID,
|
||||
AffCount: summary.AffCount,
|
||||
AffQuota: summary.AffQuota,
|
||||
AffFrozenQuota: summary.AffFrozenQuota,
|
||||
AffHistoryQuota: summary.AffHistoryQuota,
|
||||
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
|
||||
Invitees: invitees,
|
||||
@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// 有效期检查:超过返利有效期后不再产生返利
|
||||
if s.settingService != nil {
|
||||
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
|
||||
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
|
||||
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
|
||||
if rebate <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
|
||||
// 单人上限检查:精确截断到剩余额度
|
||||
if s.settingService != nil {
|
||||
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
|
||||
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if existing >= perInviteeCap {
|
||||
return 0, nil
|
||||
}
|
||||
if remaining := perInviteeCap - existing; rebate > remaining {
|
||||
rebate = roundTo(remaining, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var freezeHours int
|
||||
if s.settingService != nil {
|
||||
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
|
||||
}
|
||||
|
||||
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
user *User,
|
||||
invitationCode string,
|
||||
signupSource string,
|
||||
affiliateCode string,
|
||||
) error {
|
||||
if s == nil || user == nil || user.ID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
}
|
||||
} else {
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
||||
}
|
||||
}
|
||||
|
||||
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
|
||||
// for an OAuth-registered user. Failures are logged but never block registration.
|
||||
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
|
||||
if s.affiliateService == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
|
||||
}
|
||||
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return
|
||||
|
||||
@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.Equal(t, existing.ID, user.ID)
|
||||
|
||||
@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
|
||||
func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// API Key 限速缓存方法
|
||||
// ============================================
|
||||
|
||||
@ -17,7 +17,7 @@ const (
|
||||
// ClaudeTokenCache token cache interface.
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
|
||||
return "", errors.New("not an anthropic oauth or service account")
|
||||
}
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
return p.getServiceAccountAccessToken(ctx, account)
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
|
||||
}
|
||||
|
||||
@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
return "", errors.New("not an anthropic oauth or service account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
|
||||
@ -20,10 +20,15 @@ const (
|
||||
|
||||
// Affiliate rebate settings
|
||||
const (
|
||||
AffiliateRebateRateDefault = 20.0
|
||||
AffiliateRebateRateMin = 0.0
|
||||
AffiliateRebateRateMax = 100.0
|
||||
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
|
||||
AffiliateRebateRateDefault = 20.0
|
||||
AffiliateRebateRateMin = 0.0
|
||||
AffiliateRebateRateMax = 100.0
|
||||
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
|
||||
AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
|
||||
AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
|
||||
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
|
||||
AffiliateRebateDurationDaysMax = 3650 // ~10 年
|
||||
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
@ -37,11 +42,12 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@ -98,6 +104,9 @@ const (
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
|
||||
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
|
||||
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
|
||||
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
|
||||
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@ -299,6 +308,12 @@ const (
|
||||
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||
|
||||
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
|
||||
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
|
||||
// targets OpenAI's body-level service_tier field instead of Claude's
|
||||
// anthropic-beta header.
|
||||
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
|
||||
|
||||
// =========================
|
||||
// Claude Code Version Check
|
||||
// =========================
|
||||
@ -322,6 +337,8 @@ const (
|
||||
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
||||
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
|
||||
|
||||
// Balance Low Notification
|
||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
c.Request.Header.Set("Authorization", "Bearer inbound-token")
|
||||
c.Request.Header.Set("X-Api-Key", "inbound-api-key")
|
||||
c.Request.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeServiceAccount,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "vertex-proj",
|
||||
"location": "us-east5",
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
svc := &GatewayService{}
|
||||
req, err := svc.buildUpstreamRequest(
|
||||
context.Background(),
|
||||
c,
|
||||
account,
|
||||
body,
|
||||
"vertex-token",
|
||||
"service_account",
|
||||
"claude-sonnet-4-5@20250929",
|
||||
false,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
|
||||
require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
|
||||
require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
|
||||
require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
|
||||
require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
|
||||
|
||||
got := readRequestBodyForTest(t, req)
|
||||
require.Equal(t, "", gjson.GetBytes(got, "model").String())
|
||||
require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
|
||||
require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
|
||||
}
|
||||
|
||||
func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
|
||||
t.Helper()
|
||||
require.NotNil(t, req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
return body
|
||||
}
|
||||
@ -1,13 +1,91 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type gatewayTTLSettingRepo struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
if r == nil {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
v, ok := r.data[key]
|
||||
if !ok {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
if r == nil {
|
||||
return errors.New("setting repo is nil")
|
||||
}
|
||||
if r.data == nil {
|
||||
r.data = map[string]string{}
|
||||
}
|
||||
r.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
if r == nil {
|
||||
return result, nil
|
||||
}
|
||||
for _, key := range keys {
|
||||
if v, ok := r.data[key]; ok {
|
||||
result[key] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r == nil {
|
||||
return errors.New("setting repo is nil")
|
||||
}
|
||||
if r.data == nil {
|
||||
r.data = map[string]string{}
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.data[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
if r == nil {
|
||||
return result, nil
|
||||
}
|
||||
for key, value := range r.data {
|
||||
result[key] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error {
|
||||
if r != nil {
|
||||
delete(r.data, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
|
||||
t.Helper()
|
||||
|
||||
@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
|
||||
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||
}
|
||||
|
||||
func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
|
||||
body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
|
||||
|
||||
result := injectAnthropicCacheControlTTL1h(body)
|
||||
resultStr := string(result)
|
||||
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String())
|
||||
require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) {
|
||||
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
|
||||
}}
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||
|
||||
target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, cacheTTLTarget5m, target)
|
||||
|
||||
account.Extra = map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "1h",
|
||||
}
|
||||
target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, cacheTTLTarget1h, target)
|
||||
}
|
||||
|
||||
func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) {
|
||||
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
|
||||
}}
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
svc := &GatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
|
||||
require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
|
||||
require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken}))
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey}))
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}))
|
||||
|
||||
repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false"
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||
require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
|
||||
}
|
||||
|
||||
@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
|
||||
@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -20,6 +21,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@ -63,6 +65,11 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTTLTarget5m = "5m"
|
||||
cacheTTLTarget1h = "1h"
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@ -335,9 +342,8 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`)
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@ -677,15 +683,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
|
||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||
if parsed.MetadataUserID != "" {
|
||||
if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
|
||||
uid := ParseMetadataUserID(parsed.MetadataUserID)
|
||||
if uid != nil && uid.SessionID != "" {
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "metadata_user_id",
|
||||
"session_id", uid.SessionID,
|
||||
"device_id", uid.DeviceID,
|
||||
"is_new_format", uid.IsNewFormat,
|
||||
)
|
||||
return uid.SessionID
|
||||
}
|
||||
slog.Info("sticky.hash_metadata_parse_failed",
|
||||
"metadata_user_id", parsed.MetadataUserID,
|
||||
"parsed_nil", uid == nil,
|
||||
)
|
||||
}
|
||||
|
||||
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
|
||||
cacheableContent := s.extractCacheableContent(parsed)
|
||||
if cacheableContent != "" {
|
||||
return s.hashContent(cacheableContent)
|
||||
hash := s.hashContent(cacheableContent)
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "cacheable_content",
|
||||
"hash", hash,
|
||||
)
|
||||
return hash
|
||||
}
|
||||
|
||||
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||
@ -725,7 +747,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
}
|
||||
}
|
||||
if combined.Len() > 0 {
|
||||
return s.hashContent(combined.String())
|
||||
hash := s.hashContent(combined.String())
|
||||
slog.Info("sticky.hash_source",
|
||||
"source", "message_content_fallback",
|
||||
"hash", hash,
|
||||
"content_len", combined.Len(),
|
||||
)
|
||||
return hash
|
||||
}
|
||||
|
||||
return ""
|
||||
@ -1432,14 +1460,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
var stickySource string
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
stickySource = "prefetch"
|
||||
} else if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
stickySource = "cache"
|
||||
}
|
||||
}
|
||||
|
||||
// [DEBUG-STICKY] 调度器入口日志
|
||||
slog.Info("sticky.scheduler_entry",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"session_hash", shortSessionHash(sessionHash),
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"sticky_source", stickySource,
|
||||
"model", requestedModel,
|
||||
"load_batch", cfg.LoadBatchEnabled,
|
||||
"has_concurrency_svc", s.concurrencyService != nil,
|
||||
"excluded_count", len(excludedIDs),
|
||||
)
|
||||
|
||||
if s.debugModelRoutingEnabled() && requestedModel != "" {
|
||||
groupPlatform := ""
|
||||
if group != nil {
|
||||
@ -1615,6 +1658,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if len(routingCandidates) > 0 {
|
||||
// 1.5. 在路由账号范围内检查粘性会话
|
||||
if sessionHash != "" && stickyAccountID > 0 {
|
||||
slog.Debug("sticky.layer1_5_checking",
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"in_routing_list", containsInt64(routingAccountIDs, stickyAccountID),
|
||||
"is_excluded", isExcluded(stickyAccountID),
|
||||
"in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(),
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
@ -1638,6 +1688,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_hit",
|
||||
"account_id", stickyAccountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "slot_acquired",
|
||||
)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
@ -1788,27 +1843,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
clearSticky := shouldClearStickySession(account, requestedModel)
|
||||
if clearSticky {
|
||||
slog.Debug("sticky.layer1_5_no_routing_clear",
|
||||
"account_id", accountID,
|
||||
"reason", "should_clear_sticky_session",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(account) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
|
||||
// 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的
|
||||
// accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存
|
||||
// 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。
|
||||
platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed)
|
||||
modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)
|
||||
modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel)
|
||||
quotaOK := s.isAccountSchedulableForQuota(account)
|
||||
windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true)
|
||||
rpmOK := s.isAccountSchedulableForRPM(ctx, account, true)
|
||||
schedulable := s.isAccountSchedulableForSelection(account)
|
||||
|
||||
slog.Debug("sticky.layer1_5_no_routing_checks",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"clear_sticky", clearSticky,
|
||||
"schedulable", schedulable,
|
||||
"platform_ok", platformOK,
|
||||
"model_supported", modelSupported,
|
||||
"model_schedulable", modelSchedulable,
|
||||
"quota_ok", quotaOK,
|
||||
"window_cost_ok", windowCostOK,
|
||||
"rpm_ok", rpmOK,
|
||||
)
|
||||
|
||||
if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "session_limit",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_hit",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "slot_acquired",
|
||||
)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_slot_busy",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@ -1817,6 +1910,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_hit",
|
||||
"account_id", accountID,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"result", "wait_plan",
|
||||
)
|
||||
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
@ -1825,12 +1923,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if !clearSticky {
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "gate_check_failed",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("sticky.layer1_5_no_routing_miss",
|
||||
"account_id", accountID,
|
||||
"reason", "account_not_in_map",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if len(routingAccountIDs) == 0 && sessionHash != "" {
|
||||
slog.Debug("sticky.layer1_5_no_routing_skip",
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(),
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"reason", func() string {
|
||||
if stickyAccountID == 0 {
|
||||
return "no_sticky_binding"
|
||||
}
|
||||
return "sticky_account_excluded"
|
||||
}(),
|
||||
)
|
||||
}
|
||||
|
||||
// ============ Layer 2: 负载感知选择 ============
|
||||
slog.Debug("sticky.layer2_fallback",
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"sticky_account_id", stickyAccountID,
|
||||
"reason", "sticky_not_used_falling_back_to_load_balance",
|
||||
"total_accounts", len(accounts),
|
||||
)
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@ -3654,7 +3782,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
|
||||
} else {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
}
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
@ -3674,6 +3806,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
return apiKey, "apikey", nil
|
||||
case AccountTypeBedrock:
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
|
||||
case AccountTypeServiceAccount:
|
||||
if account.Platform != PlatformAnthropic {
|
||||
return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
|
||||
}
|
||||
if s.claudeTokenProvider == nil {
|
||||
return "", "", errors.New("claude token provider not configured")
|
||||
}
|
||||
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "service_account", nil
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@ -3781,16 +3925,6 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
return ParseMetadataUserID(metadataUserID) != nil
|
||||
}
|
||||
|
||||
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return true
|
||||
}
|
||||
if parsed == nil || c == nil {
|
||||
return false
|
||||
}
|
||||
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
}
|
||||
|
||||
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
|
||||
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
|
||||
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
|
||||
@ -4153,6 +4287,87 @@ func enforceCacheControlLimit(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。
|
||||
// 仅修改已经存在的 cache_control,不新增缓存断点。
|
||||
func injectAnthropicCacheControlTTL1h(body []byte) []byte {
|
||||
return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
|
||||
}
|
||||
|
||||
func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
|
||||
if len(body) == 0 || ttl == "" {
|
||||
return body
|
||||
}
|
||||
out := body
|
||||
var paths []string
|
||||
addPath := func(path string, value gjson.Result) {
|
||||
cc := value.Get("cache_control")
|
||||
if !cc.Exists() || cc.Get("type").String() != "ephemeral" {
|
||||
return
|
||||
}
|
||||
if cc.Get("ttl").String() == ttl {
|
||||
return
|
||||
}
|
||||
paths = append(paths, path+".cache_control.ttl")
|
||||
}
|
||||
|
||||
if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl {
|
||||
paths = append(paths, "cache_control.ttl")
|
||||
}
|
||||
|
||||
system := gjson.GetBytes(body, "system")
|
||||
if system.IsArray() {
|
||||
idx := -1
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
idx++
|
||||
addPath(fmt.Sprintf("system.%d", idx), block)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if messages.IsArray() {
|
||||
msgIdx := -1
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
msgIdx++
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
contentIdx := -1
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
contentIdx++
|
||||
addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.IsArray() {
|
||||
idx := -1
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
idx++
|
||||
addPath(fmt.Sprintf("tools.%d", idx), tool)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if next, err := sjson.SetBytes(out, path, ttl); err == nil {
|
||||
out = next
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
|
||||
if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@ -4286,6 +4501,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
if candidate, matched := account.ResolveMappedModel(reqModel); matched {
|
||||
mappedModel = candidate
|
||||
mappingSource = "account"
|
||||
} else {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "vertex"
|
||||
}
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(reqModel)
|
||||
if normalized != reqModel {
|
||||
@ -4300,6 +4527,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
|
||||
body = injectAnthropicCacheControlTTL1h(body)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@ -5799,6 +6030,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
|
||||
if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
|
||||
}
|
||||
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@ -6010,6 +6245,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
token string,
|
||||
modelID string,
|
||||
reqStream bool,
|
||||
) (*http.Request, error) {
|
||||
vertexBody, err := buildVertexAnthropicRequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, vertexBody)
|
||||
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(strings.TrimSpace(key))
|
||||
if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
|
||||
continue
|
||||
}
|
||||
wireKey := resolveWireCasing(key)
|
||||
for _, v := range values {
|
||||
addHeaderRaw(req.Header, wireKey, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Del("authorization")
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Del("cookie")
|
||||
req.Header.Del("anthropic-version")
|
||||
setHeaderRaw(req.Header, "authorization", "Bearer "+token)
|
||||
setHeaderRaw(req.Header, "content-type", "application/json")
|
||||
|
||||
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
|
||||
"url": req.URL.String(),
|
||||
"token_type": "service_account",
|
||||
"model": modelID,
|
||||
"stream": strconv.FormatBool(reqStream),
|
||||
})
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
||||
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
|
||||
@ -6567,6 +6856,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// sanitizeStreamError 返回不含网络地址的客户端可见错误描述。
|
||||
// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
|
||||
// 服务器地址(例如 "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection
|
||||
// reset by peer")。该函数只保留可识别的错误类别,原始 err 仍在调用点写入日志。
|
||||
func sanitizeStreamError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, io.ErrUnexpectedEOF):
|
||||
return "unexpected EOF"
|
||||
case errors.Is(err, io.EOF):
|
||||
return "EOF"
|
||||
case errors.Is(err, context.Canceled):
|
||||
return "canceled"
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return "deadline exceeded"
|
||||
case errors.Is(err, syscall.ECONNRESET):
|
||||
return "connection reset by peer"
|
||||
case errors.Is(err, syscall.ECONNABORTED):
|
||||
return "connection aborted"
|
||||
case errors.Is(err, syscall.ETIMEDOUT):
|
||||
return "connection timed out"
|
||||
case errors.Is(err, syscall.EPIPE):
|
||||
return "broken pipe"
|
||||
case errors.Is(err, syscall.ECONNREFUSED):
|
||||
return "connection refused"
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) {
|
||||
if netErr.Timeout() {
|
||||
if netErr.Op != "" {
|
||||
return netErr.Op + " timeout"
|
||||
}
|
||||
return "i/o timeout"
|
||||
}
|
||||
if netErr.Op != "" {
|
||||
return netErr.Op + " network error"
|
||||
}
|
||||
}
|
||||
return "upstream connection error"
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||
@ -7004,14 +7336,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。
|
||||
// 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":<reason>,"message":<message>}}
|
||||
// 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案,
|
||||
// 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
sendErrorEvent := func(reason, message string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
if message == "" {
|
||||
message = reason
|
||||
}
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]string{
|
||||
"type": reason,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback
|
||||
body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@ -7088,9 +7437,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
@ -7171,10 +7520,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
// 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY):
|
||||
// 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。
|
||||
// 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。
|
||||
// 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址,
|
||||
// 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err
|
||||
// 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。
|
||||
disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
|
||||
if !c.Writer.Written() {
|
||||
logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]string{
|
||||
"type": "upstream_disconnected",
|
||||
"message": disconnectMsg,
|
||||
},
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: body,
|
||||
RetryableOnSameAccount: true,
|
||||
}
|
||||
}
|
||||
sendErrorEvent("stream_read_error", disconnectMsg)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
line := ev.line
|
||||
@ -7233,7 +7604,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
}
|
||||
sendErrorEvent("stream_timeout")
|
||||
sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
@ -7475,6 +7846,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
return account.GetCacheTTLOverrideTarget(), true
|
||||
}
|
||||
if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
|
||||
return cacheTTLTarget5m, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
@ -7511,9 +7895,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
}
|
||||
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
|
||||
@ -8081,10 +8465,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致。
|
||||
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
|
||||
applyCacheTTLOverride(&result.Usage, overrideTarget)
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
|
||||
@ -4,9 +4,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
|
||||
}
|
||||
|
||||
// 上游中途读错误(如 HTTP/2 GOAWAY 触发的 unexpected EOF)发生在向客户端写入任何字节前:
|
||||
// 网关应返回 *UpstreamFailoverError 触发账号 failover/重试,而不是把错误事件直接发给客户端。
|
||||
func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: &streamReadCloser{err: io.ErrUnexpectedEOF},
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result, "失败移交场景下不应返回 streamingResult")
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err)
|
||||
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")
|
||||
|
||||
// ResponseBody 必须是 Anthropic 标准 error 格式:
|
||||
// 1) ExtractUpstreamErrorMessage 能正确从 error.message 提取消息(被 handleFailoverExhausted / ops 日志依赖)
|
||||
// 2) error.type 标记为 upstream_disconnected
|
||||
extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
|
||||
require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息")
|
||||
require.Contains(t, extractedMsg, "upstream stream disconnected")
|
||||
require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
|
||||
require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
|
||||
|
||||
// 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定
|
||||
require.NotContains(t, rec.Body.String(), "stream_read_error")
|
||||
}
|
||||
|
||||
// 上游已经发送过事件(c.Writer 已写过字节)后再发生读错误:
|
||||
// SSE 协议无 resume,网关只能透传 stream_read_error 错误事件给客户端,不能 failover。
|
||||
func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 第一次 Read 返回完整 SSE 事件让网关向 client 写入字节,第二次 Read 返回 EOF
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: &streamReadCloser{
|
||||
payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
|
||||
err: io.ErrUnexpectedEOF,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error")
|
||||
require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult")
|
||||
|
||||
// 不应被错误地包成 failover error
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover")
|
||||
|
||||
// 客户端必须收到 Anthropic 标准格式的 SSE error 事件,error.type=stream_read_error,
|
||||
// error.message 含具体根因(让 SDK 能解析、UI 能显示具体错误)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
|
||||
require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)")
|
||||
require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error")
|
||||
require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案")
|
||||
}
|
||||
|
||||
// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
|
||||
// 服务器地址。sanitizeStreamError 必须剥离这些信息,避免基础设施拓扑通过
|
||||
// failover ResponseBody 或 SSE error 帧返回给客户端。
|
||||
func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
|
||||
src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
|
||||
require.NoError(t, err)
|
||||
dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw := &net.OpError{
|
||||
Op: "read",
|
||||
Net: "tcp",
|
||||
Source: src,
|
||||
Addr: dst,
|
||||
Err: syscall.ECONNRESET,
|
||||
}
|
||||
|
||||
// 前置:原始 Error() 确实包含会泄露的字段(避免测试在 Go 行为变化时静默通过)
|
||||
require.Contains(t, raw.Error(), "10.0.0.1")
|
||||
require.Contains(t, raw.Error(), "52.1.2.3")
|
||||
|
||||
got := sanitizeStreamError(raw)
|
||||
require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP")
|
||||
require.NotContains(t, got, "54321", "不得泄露源端口")
|
||||
require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP")
|
||||
require.NotContains(t, got, "443", "不得泄露上游端口")
|
||||
require.Equal(t, "connection reset by peer", got)
|
||||
}
|
||||
|
||||
func TestSanitizeStreamError_KnownErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
|
||||
{"EOF", io.EOF, "EOF"},
|
||||
{"context canceled", context.Canceled, "canceled"},
|
||||
{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
|
||||
{"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"},
|
||||
{"EPIPE", syscall.EPIPE, "broken pipe"},
|
||||
{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
|
||||
{"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"},
|
||||
{"nil 返回空串", nil, ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
require.Equal(t, tc.want, sanitizeStreamError(tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// failover ResponseBody 必须用 sanitize 过的消息,避免泄露给客户端 / 写入 ops 日志
|
||||
// 时携带内部地址信息。
|
||||
func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newMinimalGatewayService()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
|
||||
dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
|
||||
netErr := &net.OpError{
|
||||
Op: "read",
|
||||
Net: "tcp",
|
||||
Source: src,
|
||||
Addr: dst,
|
||||
Err: syscall.ECONNRESET,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: &streamReadCloser{err: netErr},
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
require.Error(t, err)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.True(t, errors.As(err, &failoverErr))
|
||||
|
||||
body := string(failoverErr.ResponseBody)
|
||||
require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP")
|
||||
require.NotContains(t, body, "54321")
|
||||
require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP")
|
||||
require.NotContains(t, body, "443")
|
||||
// 仍然包含可诊断的根因
|
||||
require.Contains(t, body, "connection reset by peer")
|
||||
require.Contains(t, body, "upstream stream disconnected")
|
||||
}
|
||||
|
||||
@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
}
|
||||
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
|
||||
return 3
|
||||
case AccountTypeServiceAccount:
|
||||
// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
|
||||
// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
|
||||
return 999
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case AccountTypeServiceAccount:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if req.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return upstreamReq, "x-request-id", nil
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
body = ensureGeminiFunctionCallThoughtSignatures(body)
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case AccountTypeServiceAccount:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return upstreamReq, "x-request-id", nil
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
default:
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
|
||||
}
|
||||
|
||||
@ -15,7 +15,7 @@ const (
|
||||
geminiTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
|
||||
// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
|
||||
type GeminiTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache
|
||||
@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a gemini oauth account")
|
||||
if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
|
||||
return "", errors.New("not a gemini oauth or service account")
|
||||
}
|
||||
if account.Type == AccountTypeServiceAccount {
|
||||
return p.getServiceAccountAccessToken(ctx, account)
|
||||
}
|
||||
|
||||
cacheKey := GeminiTokenCacheKey(account)
|
||||
@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
|
||||
}
|
||||
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
if account != nil && account.Type == AccountTypeServiceAccount {
|
||||
if key, err := parseVertexServiceAccountKey(account); err == nil {
|
||||
return vertexServiceAccountCacheKey(account, key)
|
||||
}
|
||||
}
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "gemini:" + projectID
|
||||
|
||||
@ -53,6 +53,23 @@ const (
|
||||
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n</sub2api-codex-spark-image-unsupported>"
|
||||
)
|
||||
|
||||
var openAIChatGPTInternalUnsupportedFields = []string{
|
||||
"user",
|
||||
"metadata",
|
||||
"prompt_cache_retention",
|
||||
"safety_identifier",
|
||||
"stream_options",
|
||||
}
|
||||
|
||||
var openAICodexOAuthUnsupportedFields = append([]string{
|
||||
"max_output_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
}, openAIChatGPTInternalUnsupportedFields...)
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
}
|
||||
}
|
||||
|
||||
// Strip parameters unsupported by codex models via the Responses API.
|
||||
for _, key := range []string{
|
||||
"max_output_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
// prompt_cache_retention is a newer Responses API parameter (cache TTL).
|
||||
// The ChatGPT internal Codex endpoint rejects it with
|
||||
// "Unsupported parameter: prompt_cache_retention". Defense-in-depth
|
||||
// for any OAuth path that reaches this transform — the Cursor
|
||||
// Responses-shape short-circuit in ForwardAsChatCompletions strips
|
||||
// it earlier too, but we keep this line so other OAuth callers are
|
||||
// equally protected.
|
||||
"prompt_cache_retention",
|
||||
} {
|
||||
// Strip parameters unsupported by ChatGPT internal Codex endpoint.
|
||||
for _, key := range openAICodexOAuthUnsupportedFields {
|
||||
if _, ok := reqBody[key]; ok {
|
||||
delete(reqBody, key)
|
||||
result.Modified = true
|
||||
@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
reqBody["tool_choice"] = map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
},
|
||||
"name": name,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
|
||||
if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
|
||||
if choiceType == "" {
|
||||
return false
|
||||
}
|
||||
modified := false
|
||||
if choiceType == "function" {
|
||||
name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
|
||||
if name == "" {
|
||||
if function, ok := choiceMap["function"].(map[string]any); ok {
|
||||
name = strings.TrimSpace(firstNonEmptyString(function["name"]))
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
|
||||
choiceMap["name"] = name
|
||||
modified = true
|
||||
}
|
||||
if _, ok := choiceMap["function"]; ok {
|
||||
delete(choiceMap, "function")
|
||||
modified = true
|
||||
}
|
||||
if !codexToolsContainFunctionName(reqBody["tools"], name) {
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
return modified
|
||||
}
|
||||
if codexToolsContainType(reqBody["tools"], choiceType) {
|
||||
return modified
|
||||
}
|
||||
reqBody["tool_choice"] = "auto"
|
||||
return true
|
||||
}
|
||||
@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func codexToolsContainFunctionName(rawTools any, name string) bool {
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok || strings.TrimSpace(name) == "" {
|
||||
return false
|
||||
}
|
||||
normalizedName := strings.TrimSpace(name)
|
||||
for _, rawTool := range tools {
|
||||
tool, ok := rawTool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
|
||||
continue
|
||||
}
|
||||
toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
|
||||
if toolName == "" {
|
||||
if function, ok := tool["function"].(map[string]any); ok {
|
||||
toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
|
||||
}
|
||||
}
|
||||
if toolName == normalizedName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
|
||||
if len(input) == 0 {
|
||||
return input, false
|
||||
@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
|
||||
// chatgpt.com codex backend (OAuth path) does not persist reasoning
|
||||
// items because applyCodexOAuthTransform forces store=false. Any rs_*
|
||||
// reference replayed in input is guaranteed to 404 upstream
|
||||
// ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
|
||||
if typ == "reasoning" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
|
||||
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
|
||||
fixCallIDPrefix := func(id string) string {
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
|
||||
require.Equal(t, "custom", choice["type"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "name": "shell"},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": "shell"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
choice, ok := reqBody["tool_choice"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "function", choice["type"])
|
||||
require.Equal(t, "shell", choice["name"])
|
||||
require.NotContains(t, choice, "function")
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "name": "shell"},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": "missing"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
require.Equal(t, "auto", reqBody["tool_choice"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
|
||||
"prompt_cache_retention must be stripped before forwarding to Codex upstream")
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.4",
|
||||
"user": "user_123",
|
||||
"metadata": map[string]any{"trace_id": "abc"},
|
||||
"prompt_cache_retention": "24h",
|
||||
"safety_identifier": "sid",
|
||||
"stream_options": map[string]any{"include_usage": true},
|
||||
"input": []any{
|
||||
map[string]any{"role": "user", "content": "hi"},
|
||||
},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true, false)
|
||||
|
||||
require.True(t, result.Modified)
|
||||
for _, field := range openAIChatGPTInternalUnsupportedFields {
|
||||
require.NotContains(t, reqBody, field)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
|
||||
// Reasoning items in input[] reference rs_* IDs that were emitted by
|
||||
// chatgpt.com under store=false (forced by applyCodexOAuthTransform).
|
||||
// They are never persisted upstream, so forwarding them produces a
|
||||
// guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
|
||||
// regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.
|
||||
|
||||
build := func() []any {
|
||||
return []any{
|
||||
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
|
||||
map[string]any{
|
||||
"type": "reasoning",
|
||||
"id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
|
||||
"summary": []any{},
|
||||
},
|
||||
map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
|
||||
}
|
||||
}
|
||||
|
||||
for _, preserve := range []bool{true, false} {
|
||||
preserve := preserve
|
||||
t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
|
||||
filtered := filterCodexInput(build(), preserve)
|
||||
|
||||
for _, raw := range filtered {
|
||||
item, ok := raw.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "reasoning", item["type"],
|
||||
"reasoning items must be dropped from input on the OAuth path")
|
||||
if id, ok := item["id"].(string); ok {
|
||||
require.False(t, strings.HasPrefix(id, "rs_"),
|
||||
"no item carrying an rs_* id should survive the filter")
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check: the non-reasoning items should still be present.
|
||||
gotTypes := make(map[string]int)
|
||||
for _, raw := range filtered {
|
||||
item, ok := raw.(map[string]any)
|
||||
require.True(t, ok)
|
||||
typ, ok := item["type"].(string)
|
||||
require.True(t, ok)
|
||||
gotTypes[typ]++
|
||||
}
|
||||
require.Equal(t, 1, gotTypes["message"])
|
||||
require.Equal(t, 1, gotTypes["function_call"])
|
||||
require.Equal(t, 1, gotTypes["function_call_output"])
|
||||
require.Equal(t, 0, gotTypes["reasoning"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
286
backend/internal/service/openai_fast_policy_test.go
Normal file
286
backend/internal/service/openai_fast_policy_test.go
Normal file
@ -0,0 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIFastPolicyRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
if s.values == nil {
|
||||
s.values = map[string]string{}
|
||||
}
|
||||
s.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
|
||||
t.Helper()
|
||||
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
||||
if settings != nil {
|
||||
raw, err := json.Marshal(settings)
|
||||
require.NoError(t, err)
|
||||
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
|
||||
}
|
||||
return &OpenAIGatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
|
||||
// 是用户级开关,与 model 正交。
|
||||
// gpt-5.5 + priority → filter
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-5.5-turbo → filter
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-4 + priority → filter(默认策略覆盖所有模型)
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-5.5 + flex → pass (tier doesn't match)
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
|
||||
// empty tier → pass
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "fast mode is not allowed",
|
||||
ModelWhitelist: []string{"gpt-5.5"},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionBlock, action)
|
||||
require.Equal(t, "fast mode is not allowed", msg)
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierAny,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeOAuth,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
|
||||
// OAuth account → rule matches
|
||||
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// API Key account → rule skipped → pass
|
||||
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
}
|
||||
|
||||
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// gpt-5.5 fast → service_tier stripped
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// Client sending "fast" (alias for priority) also filtered
|
||||
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
|
||||
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// No service_tier → no-op
|
||||
body = []byte(`{"model":"gpt-5.5"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(updated))
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
|
||||
// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
|
||||
// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
|
||||
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
for _, tier := range []string{"auto", "default", "scale"} {
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err, "tier %q should pass without error", tier)
|
||||
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
|
||||
"tier %q should be preserved in body under default rule", tier)
|
||||
}
|
||||
|
||||
// evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
|
||||
for _, tier := range []string{"auto", "default", "scale"} {
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
|
||||
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
|
||||
// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
|
||||
// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
|
||||
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierAny,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`,
|
||||
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离
|
||||
// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段;
|
||||
// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在
|
||||
// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。
|
||||
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// normalize 阶段会将未知值剥离
|
||||
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
||||
|
||||
// applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变
|
||||
// (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离)
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(updated))
|
||||
}
|
||||
|
||||
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "fast mode is blocked for gpt-5.5",
|
||||
ModelWhitelist: []string{"gpt-5.5"},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.Error(t, err)
|
||||
var blocked *OpenAIFastBlockedError
|
||||
require.True(t, errors.As(err, &blocked))
|
||||
require.Contains(t, blocked.Message, "fast mode is blocked")
|
||||
require.Equal(t, string(body), string(updated)) // body not mutated on block
|
||||
}
|
||||
|
||||
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
|
||||
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
// Invalid action rejected
|
||||
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: "bogus",
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Invalid service_tier rejected
|
||||
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: "turbo",
|
||||
Action: BetaPolicyActionPass,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Valid settings persisted
|
||||
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got.Rules, 1)
|
||||
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
|
||||
}
|
||||
1018
backend/internal/service/openai_fast_policy_ws_test.go
Normal file
1018
backend/internal/service/openai_fast_policy_ws_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
}
|
||||
|
||||
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
responsesBody = updatedBody
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
|
||||
@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "flex", req.ServiceTier)
|
||||
|
||||
// OpenAI 官方合法 tier 应被透传保留。
|
||||
req.ServiceTier = "auto"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "auto", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "default"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "default", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "scale"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "scale", req.ServiceTier)
|
||||
|
||||
// 真未知值仍被剥离。
|
||||
req.ServiceTier = "turbo"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Empty(t, req.ServiceTier)
|
||||
}
|
||||
|
||||
@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||
require.Equal(t, "flex", tier)
|
||||
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
// OpenAI 官方 tier 直接保留在 body 中(透传上游)。
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "auto", tier)
|
||||
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "default", tier)
|
||||
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "scale", tier)
|
||||
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
// 真未知值才会被删除。
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tier)
|
||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||
}
|
||||
|
||||
@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
}
|
||||
|
||||
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
|
||||
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
|
||||
// on the body-level service_tier field (priority/flex).
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
responsesBody = updatedBody
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
|
||||
@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *got)
|
||||
})
|
||||
|
||||
t.Run("default ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("default"))
|
||||
t.Run("openai official tiers preserved", func(t *testing.T) {
|
||||
// OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
|
||||
// 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
|
||||
// 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
|
||||
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
|
||||
got := normalizeOpenAIServiceTier(tier)
|
||||
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
|
||||
require.Equal(t, tier, *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
|
||||
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
|
||||
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
|
||||
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
|
||||
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
|
||||
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
|
||||
require.Nil(t, extractOpenAIServiceTier(nil))
|
||||
}
|
||||
@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
|
||||
require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
|
||||
}
|
||||
|
||||
|
||||
@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
|
||||
resolver *ModelPricingResolver
|
||||
channelService *ChannelService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
settingService *SettingService
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
|
||||
resolver *ModelPricingResolver,
|
||||
channelService *ChannelService,
|
||||
balanceNotifyService *BalanceNotifyService,
|
||||
settingService *SettingService,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
|
||||
resolver: resolver,
|
||||
channelService: channelService,
|
||||
balanceNotifyService: balanceNotifyService,
|
||||
settingService: settingService,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@ -1125,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
|
||||
return sessionID
|
||||
}
|
||||
|
||||
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
|
||||
// client session signals. It intentionally skips content-derived fallback and is
|
||||
// used by stateless endpoints such as /v1/images.
|
||||
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
|
||||
sessionID := explicitOpenAISessionID(c, body)
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
|
||||
attachOpenAILegacySessionHashToGin(c, legacyHash)
|
||||
return currentHash
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
||||
//
|
||||
// Priority:
|
||||
@ -1137,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
sessionID := explicitOpenAISessionID(c, body)
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = deriveOpenAIContentSessionSeed(body)
|
||||
}
|
||||
@ -2287,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
disablePatch()
|
||||
}
|
||||
|
||||
// Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤):
|
||||
// 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略
|
||||
// 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽
|
||||
// fast 时在此生效。
|
||||
//
|
||||
// 注意:
|
||||
// 1. 此处统一使用 upstreamModel(已经过 GetMappedModel +
|
||||
// normalizeOpenAIModelForUpstream + Codex OAuth normalize),与
|
||||
// chat-completions / messages 入口保持一致,避免不同入口因为模型
|
||||
// 维度不同而出现 whitelist 命中差异。
|
||||
// 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body,
|
||||
// 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
|
||||
// completions 入口由 normalizeResponsesBodyServiceTier 完成同一
|
||||
// 行为,这里手工实现等效逻辑。
|
||||
if rawTier, ok := reqBody["service_tier"].(string); ok {
|
||||
if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
|
||||
}
|
||||
blocked := &OpenAIFastBlockedError{Message: msg}
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
return nil, blocked
|
||||
case BetaPolicyActionFilter:
|
||||
delete(reqBody, "service_tier")
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
default:
|
||||
// pass:若客户端传的是别名 "fast",归一化为 "priority"
|
||||
// 后写回 body,确保上游收到的是其能识别的规范值。
|
||||
if normTier != rawTier {
|
||||
reqBody["service_tier"] = normTier
|
||||
bodyModified = true
|
||||
markPatchSet("service_tier", normTier)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
serializedByPatch := false
|
||||
@ -2735,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
body = sanitizedBody
|
||||
}
|
||||
|
||||
// Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
|
||||
// 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 +
|
||||
// OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。
|
||||
// 这样可以与 chat-completions / messages / native /responses 入口的
|
||||
// upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
|
||||
// model 字段时退回 reqModel。
|
||||
policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||
if policyModel == "" {
|
||||
policyModel = reqModel
|
||||
}
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
body = updatedBody
|
||||
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||
account.ID,
|
||||
@ -4841,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
|
||||
}
|
||||
|
||||
normalized := []byte(`{}`)
|
||||
for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
|
||||
// Keep the current Codex /compact schema while still dropping request-scoped
|
||||
// fields such as prompt_cache_key, store, and stream.
|
||||
for _, field := range []string{
|
||||
"model",
|
||||
"input",
|
||||
"instructions",
|
||||
"tools",
|
||||
"parallel_tool_calls",
|
||||
"reasoning",
|
||||
"text",
|
||||
"previous_response_id",
|
||||
} {
|
||||
value := gjson.GetBytes(body, field)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
@ -5454,7 +5553,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
|
||||
}
|
||||
|
||||
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
|
||||
// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false
|
||||
// 1) 删除 ChatGPT internal API 不支持的顶层 Responses 参数
|
||||
// 2) store=false 3) 非 compact 保持 stream=true;compact 强制 stream=false
|
||||
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
|
||||
if len(body) == 0 {
|
||||
return body, false, nil
|
||||
@ -5463,6 +5563,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, boo
|
||||
normalized := body
|
||||
changed := false
|
||||
|
||||
for _, field := range openAIChatGPTInternalUnsupportedFields {
|
||||
if value := gjson.GetBytes(normalized, field); !value.Exists() {
|
||||
continue
|
||||
}
|
||||
next, err := sjson.DeleteBytes(normalized, field)
|
||||
if err != nil {
|
||||
return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err)
|
||||
}
|
||||
normalized = next
|
||||
changed = true
|
||||
}
|
||||
|
||||
if compact {
|
||||
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
|
||||
next, err := sjson.DeleteBytes(normalized, "store")
|
||||
@ -5567,14 +5679,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
|
||||
if value == "fast" {
|
||||
value = "priority"
|
||||
}
|
||||
// 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。
|
||||
// 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs),
|
||||
// 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
|
||||
// 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。
|
||||
switch value {
|
||||
case "priority", "flex":
|
||||
case "priority", "flex", "auto", "default", "scale":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
|
||||
// policy (action=block). Mirrors BetaBlockedError on the Claude side.
|
||||
type OpenAIFastBlockedError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *OpenAIFastBlockedError) Error() string { return e.Message }
|
||||
|
||||
// evaluateOpenAIFastPolicy returns the action and error message that should be
|
||||
// applied for a request with the given account/model/service_tier. When the
|
||||
// policy service is unavailable or no rule matches, it returns
|
||||
// (BetaPolicyActionPass, "") so callers can short-circuit safely.
|
||||
//
|
||||
// Matching rules:
|
||||
// - Scope filters by account type (all / oauth / apikey / bedrock)
|
||||
// - ServiceTier must be empty (= any), "all", or equal the normalized tier
|
||||
// - ModelWhitelist narrows the rule to specific models; FallbackAction
|
||||
// handles the non-matching case (default: pass)
|
||||
//
|
||||
// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit):
|
||||
// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同
|
||||
// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。
|
||||
// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段,
|
||||
// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier
|
||||
// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist
|
||||
// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而
|
||||
// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。
|
||||
func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
|
||||
if s == nil || s.settingService == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
tier := strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
if tier == "" {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
settings := openAIFastPolicySettingsFromContext(ctx)
|
||||
if settings == nil {
|
||||
fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
|
||||
if err != nil || fetched == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
settings = fetched
|
||||
}
|
||||
return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
|
||||
}
|
||||
|
||||
// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
|
||||
// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
|
||||
// the settingService on every frame. See WSSession entry and
|
||||
// openAIFastPolicySettingsFromContext for the caching glue.
|
||||
func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
|
||||
if settings == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
isOAuth := account != nil && account.IsOAuth()
|
||||
isBedrock := account != nil && account.IsBedrock()
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
|
||||
if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
|
||||
continue
|
||||
}
|
||||
eff := BetaPolicyRule{
|
||||
Action: rule.Action,
|
||||
ErrorMessage: rule.ErrorMessage,
|
||||
ModelWhitelist: rule.ModelWhitelist,
|
||||
FallbackAction: rule.FallbackAction,
|
||||
FallbackErrorMessage: rule.FallbackErrorMessage,
|
||||
}
|
||||
return resolveRuleAction(eff, model)
|
||||
}
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
|
||||
// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存
|
||||
// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。
|
||||
//
|
||||
// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是
|
||||
// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude
|
||||
// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员
|
||||
// 可以通过踢断 session 强制刷新。
|
||||
type openAIFastPolicyCtxKeyType struct{}
|
||||
|
||||
var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
|
||||
|
||||
// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx
|
||||
// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。
|
||||
func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
|
||||
if ctx == nil || settings == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
|
||||
}
|
||||
|
||||
func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
|
||||
// body. When action=filter it removes the service_tier field; when
|
||||
// action=block it returns (body, *OpenAIFastBlockedError). On pass it
|
||||
// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
|
||||
// rewriting the body so the upstream receives a slug it recognizes.
|
||||
//
|
||||
// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本
|
||||
// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化
|
||||
// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等
|
||||
// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样
|
||||
// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。
|
||||
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
rawTier := gjson.GetBytes(body, "service_tier").String()
|
||||
if rawTier == "" {
|
||||
return body, nil
|
||||
}
|
||||
normTier := normalizedOpenAIServiceTierValue(rawTier)
|
||||
if normTier == "" {
|
||||
return body, nil
|
||||
}
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
|
||||
}
|
||||
return body, &OpenAIFastBlockedError{Message: msg}
|
||||
case BetaPolicyActionFilter:
|
||||
trimmed, err := sjson.DeleteBytes(body, "service_tier")
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("strip service_tier from body: %w", err)
|
||||
}
|
||||
return trimmed, nil
|
||||
default:
|
||||
// pass:把别名(如 "fast")写回为规范值("priority")。
|
||||
if normTier == rawTier {
|
||||
return body, nil
|
||||
}
|
||||
updated, err := sjson.SetBytes(body, "service_tier", normTier)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("normalize service_tier on pass: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
}
|
||||
|
||||
// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
|
||||
// request blocked by the OpenAI fast policy.
|
||||
func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
|
||||
if c == nil || err == nil {
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": err.Message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
|
||||
// against a single client→upstream WebSocket frame whose top-level
|
||||
// "type"=="response.create". It mirrors the HTTP-side
|
||||
// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
|
||||
// WS payload:
|
||||
//
|
||||
// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
|
||||
// - filter: returns a copy with top-level service_tier removed
|
||||
// - block: returns (frame, *OpenAIFastBlockedError)
|
||||
//
|
||||
// Only frames whose "type" field strictly equals "response.create" are
|
||||
// inspected/mutated. Any other frame type — including the empty string —
|
||||
// passes through untouched. The OpenAI Realtime client-event spec requires
|
||||
// "type" to be set, so an empty type is treated as a malformed frame we do
|
||||
// not police; the upstream is the source of truth for rejecting it.
|
||||
//
|
||||
// service_tier lives at the top level of response.create — same as the
|
||||
// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
|
||||
// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
|
||||
// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
|
||||
// to inspect / strip the top-level field; there is no nested form in the
|
||||
// schema today.
|
||||
//
|
||||
// The caller is responsible for choosing the upstream model passed in —
|
||||
// this helper does not re-derive it.
|
||||
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
model string,
|
||||
frame []byte,
|
||||
) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if len(frame) == 0 {
|
||||
return frame, nil, nil
|
||||
}
|
||||
if !gjson.ValidBytes(frame) {
|
||||
return frame, nil, nil
|
||||
}
|
||||
frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
|
||||
// Strict match: only response.create is policy-checked. Empty / other
|
||||
// types pass through untouched so we never accidentally strip fields
|
||||
// from response.cancel, conversation.item.create, or any future
|
||||
// client-event the spec adds. The Realtime spec requires "type" on
|
||||
// every client event, so an empty type is malformed input — let the
|
||||
// upstream reject it rather than guessing at our layer.
|
||||
if frameType != "response.create" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
rawTier := gjson.GetBytes(frame, "service_tier").String()
|
||||
if rawTier == "" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
normTier := normalizedOpenAIServiceTierValue(rawTier)
|
||||
if normTier == "" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
|
||||
}
|
||||
return frame, &OpenAIFastBlockedError{Message: msg}, nil
|
||||
case BetaPolicyActionFilter:
|
||||
trimmed, err := sjson.DeleteBytes(frame, "service_tier")
|
||||
if err != nil {
|
||||
return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
|
||||
}
|
||||
return trimmed, nil, nil
|
||||
default:
|
||||
return frame, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
|
||||
// server-emitted error event. Matches the loose "evt_<rand>" convention used
|
||||
// by upstream Realtime servers; the exact value is not load-bearing and is
|
||||
// only required for client-side log correlation. We reuse the existing
|
||||
// google/uuid dependency rather than pulling a new one.
|
||||
func newOpenAIFastPolicyWSEventID() string {
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
// Extremely unlikely; fall back to a fixed prefix so the field is
|
||||
// still non-empty and the schema stays self-consistent.
|
||||
return "evt_openai_fast_policy"
|
||||
}
|
||||
// Strip dashes so it visually matches "evt_<hex>" rather than UUID v4
|
||||
// canonical form, mirroring what real Realtime traces look like.
|
||||
return "evt_" + strings.ReplaceAll(id.String(), "-", "")
|
||||
}
|
||||
|
||||
// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
|
||||
// style "error" event payload for a request blocked by the OpenAI fast
|
||||
// policy. The shape mirrors Realtime error events as observed in upstream
|
||||
// traces and per the spec's server "error" event:
|
||||
//
|
||||
// {
|
||||
// "event_id": "evt_<random>",
|
||||
// "type": "error",
|
||||
// "error": {
|
||||
// "type": "invalid_request_error",
|
||||
// "code": "policy_violation",
|
||||
// "message": "..."
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// event_id lets clients correlate the rejection in their logs; "code" gives
|
||||
// programmatic clients a stable identifier (HTTP-side equivalent is the
|
||||
// 403 permission_error JSON body).
|
||||
func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
eventID := newOpenAIFastPolicyWSEventID()
|
||||
payload, mErr := json.Marshal(map[string]any{
|
||||
"event_id": eventID,
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"code": "policy_violation",
|
||||
"message": err.Message,
|
||||
},
|
||||
})
|
||||
if mErr != nil {
|
||||
// Fallback to a minimal hand-rolled payload; Marshal of the literal
|
||||
// shape above should never fail in practice.
|
||||
return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
|
||||
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
|
||||
return body, false, nil
|
||||
|
||||
@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := &OpenAIGatewayService{}
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
|
||||
|
||||
t.Run("stateless image body stays unstuck", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
|
||||
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
|
||||
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("header overrides body", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
c.Request.Header.Set("session_id", "header-session")
|
||||
|
||||
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
|
||||
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@ -1756,6 +1791,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
|
||||
|
||||
normalized, changed, err := normalizeOpenAICompactRequestBody(body)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, "gpt-5.5", gjson.GetBytes(normalized, "model").String())
|
||||
require.True(t, gjson.GetBytes(normalized, "tools").Exists())
|
||||
require.True(t, gjson.GetBytes(normalized, "parallel_tool_calls").Bool())
|
||||
require.Equal(t, "high", gjson.GetBytes(normalized, "reasoning.effort").String())
|
||||
require.Equal(t, "low", gjson.GetBytes(normalized, "text.verbosity").String())
|
||||
require.Equal(t, "resp_123", gjson.GetBytes(normalized, "previous_response_id").String())
|
||||
require.False(t, gjson.GetBytes(normalized, "store").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "prompt_cache_key").Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
@ -258,6 +259,25 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
|
||||
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/generations",
|
||||
buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
|
||||
)
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/edits",
|
||||
buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
|
||||
)
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/generations",
|
||||
buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint),
|
||||
)
|
||||
require.Equal(t,
|
||||
"https://image-upstream.example/v1/images/generations",
|
||||
buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint),
|
||||
)
|
||||
}
|
||||
|
||||
type openAIImageTestSSEEvent struct {
|
||||
Name string
|
||||
Data string
|
||||
@ -371,6 +391,124 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_apikey"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, "gpt-image-2", result.Model)
|
||||
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||
|
||||
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
|
||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||
require.NoError(t, writer.WriteField("prompt", "replace background"))
|
||||
imagePart, err := writer.CreateFormFile("image", "source.png")
|
||||
require.NoError(t, err)
|
||||
_, err = imagePart.Write([]byte("png-image-content"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req_img_edit_apikey"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-api-key",
|
||||
"base_url": "https://image-upstream.example/v1/",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
|
||||
upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://image-upstream.example/v1/images/edits", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Contains(t, upstream.lastReq.Header.Get("Content-Type"), "multipart/form-data")
|
||||
require.Contains(t, string(upstream.lastBody), `name="model"`)
|
||||
require.Contains(t, string(upstream.lastBody), "gpt-image-2")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
|
||||
|
||||
@ -0,0 +1,33 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
|
||||
|
||||
normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
for _, field := range openAIChatGPTInternalUnsupportedFields {
|
||||
require.False(t, gjson.GetBytes(normalized, field).Exists(), "%s should be stripped", field)
|
||||
}
|
||||
require.True(t, gjson.GetBytes(normalized, "stream").Bool())
|
||||
require.False(t, gjson.GetBytes(normalized, "store").Bool())
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
|
||||
|
||||
normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.False(t, gjson.GetBytes(normalized, "user").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "metadata").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "stream").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "store").Exists())
|
||||
}
|
||||
@ -1366,16 +1366,27 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
|
||||
func shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
storeDisabled bool,
|
||||
turn int,
|
||||
hasFunctionCallOutput bool,
|
||||
signals ToolContinuationSignals,
|
||||
currentPreviousResponseID string,
|
||||
expectedPreviousResponseID string,
|
||||
) bool {
|
||||
if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
|
||||
if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(currentPreviousResponseID) != "" {
|
||||
return false
|
||||
}
|
||||
if signals.HasFunctionCallOutputMissingCallID {
|
||||
return false
|
||||
}
|
||||
// If the client already sent the actual tool-call context, treat this as
|
||||
// a full replay / self-contained continuation payload rather than
|
||||
// downgrading it into an inferred delta continuation. item_reference alone
|
||||
// is not enough on the store=false WS path: it still needs a valid prior
|
||||
// response anchor so upstream can resolve the referenced function_call.
|
||||
if signals.HasToolCallContext {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(expectedPreviousResponseID) != ""
|
||||
}
|
||||
|
||||
@ -2366,6 +2377,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
|
||||
// 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session
|
||||
// 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
|
||||
// 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
|
||||
if s.settingService != nil {
|
||||
if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
|
||||
ctx = withOpenAIFastPolicyContext(ctx, settings)
|
||||
}
|
||||
}
|
||||
|
||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||||
ingressMode := OpenAIWSIngressModeCtxPool
|
||||
@ -2524,6 +2544,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
normalized = next
|
||||
}
|
||||
|
||||
// Apply OpenAI Fast Policy on the response.create frame using the same
|
||||
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
|
||||
// single integration point for all WS ingress turns (first + follow-up
|
||||
// frames flow through here).
|
||||
//
|
||||
// Model fallback: parseClientPayload above rejects any frame whose
|
||||
// "model" field is missing (line ~2493-2500), so by the time we
|
||||
// reach this point upstreamModel is always derived from a non-empty
|
||||
// per-frame model. The capturedSessionModel fallback used in the
|
||||
// passthrough adapter is therefore not needed in this path.
|
||||
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
||||
if policyErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
||||
}
|
||||
if blocked != nil {
|
||||
// Send a Realtime-style error event to the client first, then
|
||||
// signal the handler to close the connection with PolicyViolation.
|
||||
// We intentionally do NOT forward this frame upstream.
|
||||
//
|
||||
// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
|
||||
// the underlying bufio writer before returning (write.go:42 →
|
||||
// 307-311), and the subsequent close handshake re-acquires the
|
||||
// same writeFrameMu, so the error event is guaranteed to reach
|
||||
// the kernel send buffer before any close frame is queued.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes != nil {
|
||||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancel()
|
||||
}
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
blocked.Message,
|
||||
blocked,
|
||||
)
|
||||
}
|
||||
normalized = policyApplied
|
||||
|
||||
return openAIWSClientPayload{
|
||||
payloadRaw: normalized,
|
||||
rawForHash: trimmed,
|
||||
@ -3132,13 +3190,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
skipBeforeTurn = false
|
||||
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
|
||||
expectedPrev := strings.TrimSpace(lastTurnResponseID)
|
||||
hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
||||
toolSignals := ToolContinuationSignals{
|
||||
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
|
||||
}
|
||||
if toolSignals.HasFunctionCallOutput {
|
||||
var currentReqBody map[string]any
|
||||
if err := json.Unmarshal(currentPayload, ¤tReqBody); err == nil {
|
||||
toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
|
||||
}
|
||||
}
|
||||
hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
|
||||
// store=false + function_call_output 场景必须有续链锚点。
|
||||
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
|
||||
if shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
storeDisabled,
|
||||
turn,
|
||||
hasFunctionCallOutput,
|
||||
toolSignals,
|
||||
currentPreviousResponseID,
|
||||
expectedPrev,
|
||||
) {
|
||||
|
||||
@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSQueueDialer{
|
||||
conns: []openAIWSClientConn{captureConn},
|
||||
}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 114,
|
||||
Name: "openai-ingress-tool-context",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
|
||||
secondTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
|
||||
|
||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount())
|
||||
require.Len(t, captureConn.writes, 2)
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSQueueDialer{
|
||||
conns: []openAIWSClientConn{captureConn},
|
||||
}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 115,
|
||||
Name: "openai-ingress-item-reference",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
|
||||
secondTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
|
||||
|
||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount())
|
||||
require.Len(t, captureConn.writes, 2)
|
||||
require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
|
||||
|
||||
@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
|
||||
name string
|
||||
storeDisabled bool
|
||||
turn int
|
||||
hasFunctionCallOutput bool
|
||||
signals ToolContinuationSignals
|
||||
currentPreviousResponse string
|
||||
expectedPrevious string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "infer_when_all_conditions_match",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: true,
|
||||
name: "infer_when_all_conditions_match",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "resp_1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_store_enabled",
|
||||
storeDisabled: false,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
name: "skip_when_store_enabled",
|
||||
storeDisabled: false,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_on_first_turn",
|
||||
storeDisabled: true,
|
||||
turn: 1,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
name: "skip_on_first_turn",
|
||||
storeDisabled: true,
|
||||
turn: 1,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_without_function_call_output",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: false,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
name: "skip_without_function_call_output",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{},
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_request_already_has_previous_response_id",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
currentPreviousResponse: "resp_client",
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_last_turn_response_id_missing",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "",
|
||||
want: false,
|
||||
name: "skip_when_last_turn_response_id_missing",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "trim_whitespace_before_judgement",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: " resp_2 ",
|
||||
want: true,
|
||||
name: "trim_whitespace_before_judgement",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: " resp_2 ",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_tool_call_context_already_present",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "infer_when_only_item_reference_covers_call_ids",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_function_call_output_missing_call_id",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
|
||||
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
tt.storeDisabled,
|
||||
tt.turn,
|
||||
tt.hasFunctionCallOutput,
|
||||
tt.signals,
|
||||
tt.currentPreviousResponse,
|
||||
tt.expectedPrevious,
|
||||
)
|
||||
|
||||
@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
||||
|
||||
@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
|
||||
conn *coderws.Conn
|
||||
}
|
||||
|
||||
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
|
||||
// every client→upstream frame through the OpenAI Fast Policy. It is the
|
||||
// passthrough-relay equivalent of the parseClientPayload integration in the
|
||||
// ingress session path. filter returns:
|
||||
// - newPayload, nil, nil: forward the (possibly mutated) payload
|
||||
// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
|
||||
// event via onBlock and surfaces a transport-level error so the relay
|
||||
// stops reading from the client.
|
||||
// - _, _, err: a transport error other than block.
|
||||
type openAIWSPolicyEnforcingFrameConn struct {
|
||||
inner openaiwsv2.FrameConn
|
||||
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
|
||||
onBlock func(blocked *OpenAIFastBlockedError)
|
||||
}
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
|
||||
|
||||
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if c == nil || c.inner == nil {
|
||||
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||
}
|
||||
msgType, payload, err := c.inner.ReadFrame(ctx)
|
||||
if err != nil {
|
||||
return msgType, payload, err
|
||||
}
|
||||
if c.filter == nil {
|
||||
return msgType, payload, nil
|
||||
}
|
||||
updated, blocked, filterErr := c.filter(msgType, payload)
|
||||
if filterErr != nil {
|
||||
return msgType, payload, filterErr
|
||||
}
|
||||
if blocked != nil {
|
||||
if c.onBlock != nil {
|
||||
c.onBlock(blocked)
|
||||
}
|
||||
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
||||
}
|
||||
return msgType, updated, nil
|
||||
}
|
||||
|
||||
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||
if c == nil || c.inner == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
return c.inner.WriteFrame(ctx, msgType, payload)
|
||||
}
|
||||
|
||||
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
|
||||
if c == nil || c.inner == nil {
|
||||
return nil
|
||||
}
|
||||
return c.inner.Close()
|
||||
}
|
||||
|
||||
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
|
||||
// model name that should be passed to evaluateOpenAIFastPolicy for a single
|
||||
// passthrough WS frame. Mirrors the HTTP-side normalization
|
||||
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
|
||||
// matches model whitelists identically.
|
||||
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
|
||||
if account == nil || len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if original == "" {
|
||||
return ""
|
||||
}
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
|
||||
// derived from a session.update frame's session.model field. Returns "" when
|
||||
// the frame is not a session.update event or carries no session.model. Used
|
||||
// by the per-frame policy filter (client→upstream direction) to keep
|
||||
// capturedSessionModel in sync with the session-level model the client may
|
||||
// rotate mid-session.
|
||||
//
|
||||
// Realtime / Responses WS lets the client change the session model after
|
||||
// the WS handshake via:
|
||||
//
|
||||
// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
|
||||
//
|
||||
// If we only capture the model from the very first frame, a client can ship
|
||||
// gpt-4o on the first response.create (whitelisted as pass), then
|
||||
// session.update to gpt-5.5, then send response.create without "model" so
|
||||
// the per-frame resolver returns "" and the stale capturedSessionModel falls
|
||||
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
|
||||
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
|
||||
if account == nil || len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||
if frameType != "session.update" {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||||
if original == "" {
|
||||
return ""
|
||||
}
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||
@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
|
||||
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||
@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
len(firstClientMessage),
|
||||
)
|
||||
|
||||
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
|
||||
// frames are filtered via a wrapping FrameConn below so every client→
|
||||
// upstream frame goes through the same policy evaluator/normalize/scope as
|
||||
// HTTP entrypoints.
|
||||
//
|
||||
// We capture the session-level model from the first frame here so the
|
||||
// per-frame filter (below) can fall back to it when a follow-up frame
|
||||
// omits "model" — Realtime clients are allowed to send response.create
|
||||
// without re-stating the model, in which case the upstream uses the model
|
||||
// negotiated at session.update time. Without this fallback, an empty
|
||||
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
|
||||
// silently passed through, defeating the policy on every frame after
|
||||
// the first.
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||||
if policyErr != nil {
|
||||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||||
}
|
||||
if blocked != nil {
|
||||
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
|
||||
// writeFrameMu, writes the entire frame, and Flushes the underlying
|
||||
// bufio writer before returning (write.go:42 → write.go:307-311).
|
||||
// The subsequent close handshake re-acquires the same writeFrameMu
|
||||
// to send the close frame, so the error event is guaranteed to
|
||||
// reach the kernel send buffer before any close frame is queued.
|
||||
// No explicit flush hop is required here.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes != nil {
|
||||
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancelWrite()
|
||||
}
|
||||
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
||||
}
|
||||
firstClientMessage = updatedFirst
|
||||
|
||||
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
||||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||||
// 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
|
||||
//
|
||||
// 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
|
||||
// 同一连接的不同 response.create 帧上发送不同 service_tier(参考
|
||||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||
// goroutine)之间同步当前 turn 的 service_tier。
|
||||
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
||||
// 可直接 Store/Load 而无需额外封装。
|
||||
var requestServiceTierPtr atomic.Pointer[string]
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build ws url: %w", err)
|
||||
@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
}
|
||||
|
||||
completedTurns := atomic.Int32{}
|
||||
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
|
||||
inner: &openAIWSClientFrameConn{conn: clientConn},
|
||||
// 注意线程安全:filter 仅在 runClientToUpstream 这一条
|
||||
// goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
|
||||
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
|
||||
// 加锁/原子化。
|
||||
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if msgType != coderws.MessageText {
|
||||
return payload, nil, nil
|
||||
}
|
||||
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
||||
// session.update 修改 session-level model(Realtime /
|
||||
// Responses WS 协议允许),如果不刷新就会出现
|
||||
// "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
|
||||
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
|
||||
// 绕过路径。这里只看 session.update 事件中的 session.model
|
||||
// 字段,response.create 自己的 model 仍然由其本帧字段决定。
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
// Per-frame model first; if the client omits "model" on a
|
||||
// follow-up frame (legal in Realtime), fall back to the
|
||||
// session-level model captured from the first frame so the
|
||||
// model whitelist still resolves. An empty model would miss
|
||||
// any whitelist and silently fall back to pass.
|
||||
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
||||
if model == "" {
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||||
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
||||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||||
// - 非 response.create 帧(response.cancel /
|
||||
// conversation.item.create / session.update 等)不携带
|
||||
// per-response service_tier,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
||||
// 上一轮值。
|
||||
// - policyErr != nil:异常路径,保持上一轮值。
|
||||
// - 不带 service_tier 的 response.create 会让
|
||||
// extractOpenAIServiceTierFromBody 返回 nil;这里有意
|
||||
// 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
|
||||
// service_tier 时按 default 处理,billing 应如实反映。
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
},
|
||||
onBlock: func(blocked *OpenAIFastBlockedError) {
|
||||
// See note above on Conn.Write being synchronous w.r.t. flush;
|
||||
// no explicit flush is required to ensure the error event lands
|
||||
// before the close frame.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes == nil {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancel()
|
||||
},
|
||||
}
|
||||
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||
Ctx: ctx,
|
||||
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
|
||||
ClientConn: policyClientConn,
|
||||
UpstreamConn: upstreamFrameConn,
|
||||
FirstClientMessage: firstClientMessage,
|
||||
Options: openaiwsv2.RelayOptions{
|
||||
@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string {
|
||||
)
|
||||
}
|
||||
|
||||
// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。
|
||||
// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据
|
||||
// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true
|
||||
// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天
|
||||
//
|
||||
// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE":
|
||||
// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成
|
||||
// - 无 WAL 写入、无后续 VACUUM 压力
|
||||
// - 这些 ops 表只有 cleanup 任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略
|
||||
func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
|
||||
if days < 0 {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
if days == 0 {
|
||||
return time.Time{}, true, true
|
||||
}
|
||||
return now.AddDate(0, 0, -days), false, true
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
|
||||
out := opsCleanupDeletedCounts{}
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Error-like tables: error logs / retry attempts / alert events.
|
||||
if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
|
||||
// runOne 把"truncate? cutoff? batched delete?"封装到一处,
|
||||
// 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。
|
||||
runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) {
|
||||
if truncate {
|
||||
return truncateOpsTable(ctx, s.db, table)
|
||||
}
|
||||
return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate)
|
||||
}
|
||||
|
||||
// Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits.
|
||||
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok {
|
||||
n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.errorLogs = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
|
||||
n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.retryAttempts = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
|
||||
n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.alertEvents = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false)
|
||||
n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.systemLogs = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false)
|
||||
n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
|
||||
}
|
||||
|
||||
// Minute-level metrics snapshots.
|
||||
if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
|
||||
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok {
|
||||
n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
|
||||
}
|
||||
|
||||
// Pre-aggregation tables (hourly/daily).
|
||||
if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
|
||||
if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok {
|
||||
n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.hourlyPreagg = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
|
||||
n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch)
|
||||
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
|
||||
if err != nil {
|
||||
// If ops tables aren't present yet (partial deployments), treat as no-op.
|
||||
if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
|
||||
if isMissingRelationError(err) {
|
||||
return total, nil
|
||||
}
|
||||
return total, err
|
||||
@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch)
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。
|
||||
//
|
||||
// 与 deleteOldRowsByID 的差异:
|
||||
// - 不可指定 WHERE 条件,仅用于 days==0 的"清空全部"语义
|
||||
// - O(1) 释放表的物理存储页,毫秒级完成,无 WAL 写入、无 VACUUM 压力
|
||||
// - 需要 ACCESS EXCLUSIVE 锁,但 ops 表只有清理任务自己写入,瞬间锁影响可忽略
|
||||
//
|
||||
// 表不存在(部分部署)静默返回 0,与 deleteOldRowsByID 保持一致。
|
||||
func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil {
|
||||
if isMissingRelationError(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("count %s: %w", table, err)
|
||||
}
|
||||
if count == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil {
|
||||
if isMissingRelationError(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("truncate %s: %w", table, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。
|
||||
func isMissingRelationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
|
||||
if s == nil {
|
||||
return nil, false
|
||||
|
||||
64
backend/internal/service/ops_cleanup_service_test.go
Normal file
64
backend/internal/service/ops_cleanup_service_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOpsCleanupPlan(t *testing.T) {
|
||||
now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
days int
|
||||
wantOK bool
|
||||
wantTruncate bool
|
||||
wantCutoff time.Time
|
||||
}{
|
||||
{name: "negative skips", days: -1, wantOK: false},
|
||||
{name: "zero truncates", days: 0, wantOK: true, wantTruncate: true},
|
||||
{name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cutoff, truncate, ok := opsCleanupPlan(now, tc.days)
|
||||
if ok != tc.wantOK {
|
||||
t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if truncate != tc.wantTruncate {
|
||||
t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate)
|
||||
}
|
||||
if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) {
|
||||
t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMissingRelationError(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{name: "nil is not missing", err: nil, want: false},
|
||||
{name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true},
|
||||
{name: "match case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true},
|
||||
{name: "non-matching error", err: fakeErr("connection refused"), want: false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := isMissingRelationError(tc.err); got != tc.want {
|
||||
t.Fatalf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeErr string
|
||||
|
||||
func (e fakeErr) Error() string { return string(e) }
|
||||
@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
|
||||
if cfg.DataRetention.CleanupSchedule == "" {
|
||||
cfg.DataRetention.CleanupSchedule = "0 2 * * *"
|
||||
}
|
||||
if cfg.DataRetention.ErrorLogRetentionDays <= 0 {
|
||||
// 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留;
|
||||
// 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。
|
||||
if cfg.DataRetention.ErrorLogRetentionDays < 0 {
|
||||
cfg.DataRetention.ErrorLogRetentionDays = 30
|
||||
}
|
||||
if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 {
|
||||
if cfg.DataRetention.MinuteMetricsRetentionDays < 0 {
|
||||
cfg.DataRetention.MinuteMetricsRetentionDays = 30
|
||||
}
|
||||
if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 {
|
||||
if cfg.DataRetention.HourlyMetricsRetentionDays < 0 {
|
||||
cfg.DataRetention.HourlyMetricsRetentionDays = 30
|
||||
}
|
||||
// Normalize auto refresh interval (default 30 seconds)
|
||||
@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
|
||||
if cfg == nil {
|
||||
return errors.New("invalid config")
|
||||
}
|
||||
if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
|
||||
return errors.New("error_log_retention_days must be between 1 and 365")
|
||||
// 保留天数:0 表示每次清理全部,1-365 表示按天数保留。
|
||||
if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
|
||||
return errors.New("error_log_retention_days must be between 0 and 365")
|
||||
}
|
||||
if cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
|
||||
return errors.New("minute_metrics_retention_days must be between 1 and 365")
|
||||
if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
|
||||
return errors.New("minute_metrics_retention_days must be between 0 and 365")
|
||||
}
|
||||
if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
|
||||
return errors.New("hourly_metrics_retention_days must be between 1 and 365")
|
||||
if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
|
||||
return errors.New("hourly_metrics_retention_days must be between 0 and 365")
|
||||
}
|
||||
if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
|
||||
return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
|
||||
|
||||
@ -269,7 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
||||
|
||||
switch action {
|
||||
case redeemActionSkipCompleted:
|
||||
s.applyAffiliateRebateForOrder(ctx, o)
|
||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||
return err
|
||||
}
|
||||
// Code already created and redeemed — just mark completed
|
||||
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
|
||||
case redeemActionCreate:
|
||||
@ -283,7 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
|
||||
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
|
||||
return fmt.Errorf("redeem balance: %w", err)
|
||||
}
|
||||
s.applyAffiliateRebateForOrder(ctx, o)
|
||||
if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
|
||||
}
|
||||
|
||||
@ -361,12 +365,12 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
|
||||
return c > 0
|
||||
}
|
||||
|
||||
func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) {
|
||||
func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
if s.affiliateService == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
@ -374,7 +378,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
|
||||
})
|
||||
return
|
||||
return fmt.Errorf("begin affiliate rebate tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
@ -384,10 +388,10 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
return fmt.Errorf("claim affiliate rebate audit: %w", err)
|
||||
}
|
||||
if !claimed {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
|
||||
@ -395,7 +399,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
return fmt.Errorf("accrue affiliate rebate: %w", err)
|
||||
}
|
||||
|
||||
if rebateAmount <= 0 {
|
||||
@ -406,14 +410,15 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
return fmt.Errorf("update affiliate rebate skipped audit: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
|
||||
})
|
||||
return fmt.Errorf("commit affiliate rebate tx: %w", err)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
|
||||
@ -423,14 +428,16 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
return fmt.Errorf("update affiliate rebate applied audit: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
|
||||
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
|
||||
})
|
||||
return fmt.Errorf("commit affiliate rebate tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
|
||||
@ -444,11 +451,11 @@ func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, clien
|
||||
})
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
|
||||
SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW()
|
||||
SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW()
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM payment_audit_logs
|
||||
WHERE order_id = $1
|
||||
WHERE order_id = $1::text
|
||||
AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
|
||||
)
|
||||
ON CONFLICT (order_id, action) DO NOTHING
|
||||
|
||||
@ -59,6 +59,8 @@ type SchedulerCache interface {
|
||||
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
// TryLockBucket 尝试获取分桶重建锁。
|
||||
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
|
||||
// UnlockBucket 释放分桶重建锁。
|
||||
UnlockBucket(ctx context.Context, bucket SchedulerBucket) error
|
||||
// ListBuckets 返回已注册的分桶集合。
|
||||
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
|
||||
// GetOutboxWatermark 读取 outbox 水位。
|
||||
|
||||
@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
_ = s.cache.UnlockBucket(ctx, bucket)
|
||||
}()
|
||||
|
||||
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@ -82,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second
|
||||
|
||||
// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL)
|
||||
type cachedGatewayForwardingSettings struct {
|
||||
fingerprintUnification bool
|
||||
metadataPassthrough bool
|
||||
cchSigning bool
|
||||
expiresAt int64 // unix nano
|
||||
fingerprintUnification bool
|
||||
metadataPassthrough bool
|
||||
cchSigning bool
|
||||
anthropicCacheTTL1hInjection bool
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings
|
||||
@ -1175,6 +1176,24 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
|
||||
updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
|
||||
if settings.AffiliateRebateFreezeHours < 0 {
|
||||
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax {
|
||||
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours)
|
||||
if settings.AffiliateRebateDurationDays < 0 {
|
||||
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax {
|
||||
settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax
|
||||
}
|
||||
updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays)
|
||||
if settings.AffiliateRebatePerInviteeCap < 0 {
|
||||
settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
|
||||
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
@ -1227,6 +1246,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
|
||||
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
|
||||
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
|
||||
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
||||
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
||||
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
||||
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
||||
@ -1287,10 +1307,11 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
||||
})
|
||||
gatewayForwardingSF.Forget("gateway_forwarding")
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||
fingerprintUnification: settings.EnableFingerprintUnification,
|
||||
metadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
cchSigning: settings.EnableCCHSigning,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
fingerprintUnification: settings.EnableFingerprintUnification,
|
||||
metadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
cchSigning: settings.EnableCCHSigning,
|
||||
anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
})
|
||||
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||
@ -1397,22 +1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
|
||||
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
|
||||
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
|
||||
type gatewayForwardingSettingsResult struct {
|
||||
fp, mp, cch, cacheTTL1h bool
|
||||
}
|
||||
|
||||
func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
|
||||
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning
|
||||
return gatewayForwardingSettingsResult{
|
||||
fp: cached.fingerprintUnification,
|
||||
mp: cached.metadataPassthrough,
|
||||
cch: cached.cchSigning,
|
||||
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
||||
}
|
||||
}
|
||||
}
|
||||
type gwfResult struct {
|
||||
fp, mp, cch bool
|
||||
}
|
||||
val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) {
|
||||
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil
|
||||
return gatewayForwardingSettingsResult{
|
||||
fp: cached.fingerprintUnification,
|
||||
mp: cached.metadataPassthrough,
|
||||
cch: cached.cchSigning,
|
||||
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout)
|
||||
@ -1421,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
|
||||
SettingKeyEnableFingerprintUnification,
|
||||
SettingKeyEnableMetadataPassthrough,
|
||||
SettingKeyEnableCCHSigning,
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("failed to get gateway forwarding settings", "error", err)
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||
fingerprintUnification: true,
|
||||
metadataPassthrough: false,
|
||||
cchSigning: false,
|
||||
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
|
||||
fingerprintUnification: true,
|
||||
metadataPassthrough: false,
|
||||
cchSigning: false,
|
||||
anthropicCacheTTL1hInjection: false,
|
||||
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
|
||||
})
|
||||
return gwfResult{true, false, false}, nil
|
||||
return gatewayForwardingSettingsResult{fp: true}, nil
|
||||
}
|
||||
fp := true
|
||||
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
|
||||
@ -1438,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
|
||||
}
|
||||
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
|
||||
cch := values[SettingKeyEnableCCHSigning] == "true"
|
||||
cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||
fingerprintUnification: fp,
|
||||
metadataPassthrough: mp,
|
||||
cchSigning: cch,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
fingerprintUnification: fp,
|
||||
metadataPassthrough: mp,
|
||||
cchSigning: cch,
|
||||
anthropicCacheTTL1hInjection: cacheTTL1h,
|
||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||
})
|
||||
return gwfResult{fp, mp, cch}, nil
|
||||
return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil
|
||||
})
|
||||
if r, ok := val.(gwfResult); ok {
|
||||
return r.fp, r.mp, r.cch
|
||||
if r, ok := val.(gatewayForwardingSettingsResult); ok {
|
||||
return r
|
||||
}
|
||||
return true, false, false // fail-open defaults
|
||||
return gatewayForwardingSettingsResult{fp: true}
|
||||
}
|
||||
|
||||
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
|
||||
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
|
||||
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
|
||||
result := s.getGatewayForwardingSettingsCached(ctx)
|
||||
return result.fp, result.mp, result.cch
|
||||
}
|
||||
|
||||
// IsAnthropicCacheTTL1hInjectionEnabled 检查是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl。
|
||||
func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool {
|
||||
return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
@ -1512,6 +1558,54 @@ func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) floa
|
||||
return clampAffiliateRebateRate(rate)
|
||||
}
|
||||
|
||||
// GetAffiliateRebateFreezeHours 返回返利冻结期(小时)。
|
||||
// 返回 0 表示不冻结(向后兼容)。
|
||||
func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours)
|
||||
if err != nil {
|
||||
return AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
hours, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || hours < 0 {
|
||||
return AffiliateRebateFreezeHoursDefault
|
||||
}
|
||||
if hours > AffiliateRebateFreezeHoursMax {
|
||||
return AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
return hours
|
||||
}
|
||||
|
||||
// GetAffiliateRebateDurationDays 返回返利有效期(天)。
|
||||
// 返回 0 表示永久有效。
|
||||
func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays)
|
||||
if err != nil {
|
||||
return AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
days, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || days < 0 {
|
||||
return AffiliateRebateDurationDaysDefault
|
||||
}
|
||||
if days > AffiliateRebateDurationDaysMax {
|
||||
return AffiliateRebateDurationDaysMax
|
||||
}
|
||||
return days
|
||||
}
|
||||
|
||||
// GetAffiliateRebatePerInviteeCap 返回单人返利上限。
|
||||
// 返回 0 表示无上限。
|
||||
func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap)
|
||||
if err != nil {
|
||||
return AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
cap, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
|
||||
if err != nil || cap < 0 || math.IsNaN(cap) || math.IsInf(cap, 0) {
|
||||
return AffiliateRebatePerInviteeCapDefault
|
||||
}
|
||||
return cap
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证
|
||||
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
@ -1755,6 +1849,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
|
||||
SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
|
||||
SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
|
||||
SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
|
||||
SettingKeyDefaultUserRPMLimit: "0",
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||
@ -1811,12 +1908,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyMaxClaudeCodeVersion: "",
|
||||
|
||||
// 分组隔离(默认不允许未分组 Key 调度)
|
||||
SettingKeyAllowUngroupedKeyScheduling: "false",
|
||||
SettingPaymentVisibleMethodAlipaySource: "",
|
||||
SettingPaymentVisibleMethodWxpaySource: "",
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "false",
|
||||
openAIAdvancedSchedulerSettingKey: "false",
|
||||
SettingKeyAllowUngroupedKeyScheduling: "false",
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
||||
SettingPaymentVisibleMethodAlipaySource: "",
|
||||
SettingPaymentVisibleMethodWxpaySource: "",
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "false",
|
||||
openAIAdvancedSchedulerSettingKey: "false",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
@ -1890,6 +1988,21 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
} else {
|
||||
result.AffiliateRebateRate = AffiliateRebateRateDefault
|
||||
}
|
||||
if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
|
||||
if freezeHours > AffiliateRebateFreezeHoursMax {
|
||||
freezeHours = AffiliateRebateFreezeHoursMax
|
||||
}
|
||||
result.AffiliateRebateFreezeHours = freezeHours
|
||||
}
|
||||
if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
|
||||
if durationDays > AffiliateRebateDurationDaysMax {
|
||||
durationDays = AffiliateRebateDurationDaysMax
|
||||
}
|
||||
result.AffiliateRebateDurationDays = durationDays
|
||||
}
|
||||
if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
|
||||
result.AffiliateRebatePerInviteeCap = perInviteeCap
|
||||
}
|
||||
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
@ -2144,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
}
|
||||
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
||||
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
||||
result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
||||
|
||||
// Web search emulation: quick enabled check from the JSON config
|
||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||
@ -3175,6 +3289,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
|
||||
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
|
||||
}
|
||||
|
||||
// GetOpenAIFastPolicySettings 获取 OpenAI fast 策略配置
|
||||
func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return DefaultOpenAIFastPolicySettings(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("get openai fast policy settings: %w", err)
|
||||
}
|
||||
if value == "" {
|
||||
return DefaultOpenAIFastPolicySettings(), nil
|
||||
}
|
||||
|
||||
var settings OpenAIFastPolicySettings
|
||||
if err := json.Unmarshal([]byte(value), &settings); err != nil {
|
||||
// JSON 损坏时静默 fallback 到默认配置会让策略意外失效(管理员配
|
||||
// 置的 block/filter 规则被忽略)。记录 Warn 让运维能在出现异常
|
||||
// 行为时定位到 settings 表里的脏数据。
|
||||
slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
|
||||
"error", err,
|
||||
"key", SettingKeyOpenAIFastPolicySettings)
|
||||
return DefaultOpenAIFastPolicySettings(), nil
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// SetOpenAIFastPolicySettings 设置 OpenAI fast 策略配置
|
||||
func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
|
||||
if settings == nil {
|
||||
return fmt.Errorf("settings cannot be nil")
|
||||
}
|
||||
|
||||
validActions := map[string]bool{
|
||||
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||
}
|
||||
validScopes := map[string]bool{
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
|
||||
}
|
||||
validTiers := map[string]bool{
|
||||
OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
|
||||
}
|
||||
|
||||
for i, rule := range settings.Rules {
|
||||
tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
|
||||
if tier == "" {
|
||||
tier = OpenAIFastTierAny
|
||||
}
|
||||
if !validTiers[tier] {
|
||||
return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
|
||||
}
|
||||
settings.Rules[i].ServiceTier = tier
|
||||
if !validActions[rule.Action] {
|
||||
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
|
||||
}
|
||||
if !validScopes[rule.Scope] {
|
||||
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
|
||||
}
|
||||
for j, pattern := range rule.ModelWhitelist {
|
||||
trimmed := strings.TrimSpace(pattern)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
|
||||
}
|
||||
settings.Rules[i].ModelWhitelist[j] = trimmed
|
||||
}
|
||||
if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
|
||||
return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal openai fast policy settings: %w", err)
|
||||
}
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
|
||||
}
|
||||
|
||||
// SetStreamTimeoutSettings 设置流超时处理配置
|
||||
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
|
||||
if settings == nil {
|
||||
|
||||
@ -104,12 +104,15 @@ type SystemSettings struct {
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
AffiliateEnabled bool
|
||||
AffiliateRebateRate float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
AffiliateEnabled bool
|
||||
AffiliateRebateRate float64
|
||||
AffiliateRebateFreezeHours int
|
||||
AffiliateRebateDurationDays int
|
||||
AffiliateRebatePerInviteeCap float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@ -146,9 +149,10 @@ type SystemSettings struct {
|
||||
BackendModeEnabled bool
|
||||
|
||||
// Gateway forwarding behavior
|
||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
||||
@ -402,3 +406,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI Fast Policy 策略常量
|
||||
// OpenAI 的 "fast 模式" 通过请求体中的 service_tier 字段识别:
|
||||
// - "priority"(客户端可传 "fast",归一化为 "priority"):fast 模式
|
||||
// - "flex":低优先级模式
|
||||
// - 省略:normal 默认
|
||||
//
|
||||
// 本策略复用 BetaPolicyAction*/BetaPolicyScope* 常量语义,只是匹配键从
|
||||
// anthropic-beta header 换成 body 的 service_tier 字段。
|
||||
const (
|
||||
OpenAIFastTierAny = "all" // 匹配任意已识别的 service_tier
|
||||
OpenAIFastTierPriority = "priority" // 仅匹配 fast(priority)
|
||||
OpenAIFastTierFlex = "flex" // 仅匹配 flex
|
||||
)
|
||||
|
||||
// OpenAIFastPolicyRule 单条 OpenAI fast/flex 策略规则
|
||||
type OpenAIFastPolicyRule struct {
|
||||
ServiceTier string `json:"service_tier"` // "priority" | "flex" | "auto" | "default" | "scale" | "all"
|
||||
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||
ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效)
|
||||
FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式
|
||||
FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效)
|
||||
}
|
||||
|
||||
// OpenAIFastPolicySettings OpenAI fast 策略配置
|
||||
type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。
|
||||
// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段,
|
||||
// 让上游按 normal 优先级处理。
|
||||
//
|
||||
// 为什么 ModelWhitelist 为空(=对所有模型生效):
|
||||
// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使
|
||||
// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁
|
||||
// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。
|
||||
// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定
|
||||
// 模型,可在 admin UI 中显式配置 model_whitelist。
|
||||
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
|
||||
return &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{
|
||||
{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ModelWhitelist: []string{},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
345
backend/internal/service/vertex_service_account.go
Normal file
345
backend/internal/service/vertex_service_account.go
Normal file
@ -0,0 +1,345 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
vertexDefaultLocation = "us-central1"
|
||||
vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
|
||||
vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
||||
vertexServiceAccountCacheSkew = 5 * time.Minute
|
||||
vertexLockWaitTime = 200 * time.Millisecond
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
var (
|
||||
vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)
|
||||
vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
|
||||
vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
|
||||
)
|
||||
|
||||
type vertexServiceAccountKey struct {
|
||||
Type string `json:"type"`
|
||||
ProjectID string `json:"project_id"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ClientEmail string `json:"client_email"`
|
||||
TokenURI string `json:"token_uri"`
|
||||
}
|
||||
|
||||
type vertexTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Error string `json:"error"`
|
||||
ErrorDesc string `json:"error_description"`
|
||||
}
|
||||
|
||||
func (a *Account) IsVertexServiceAccount() bool {
|
||||
return a != nil && a.Type == AccountTypeServiceAccount
|
||||
}
|
||||
|
||||
func (a *Account) VertexProjectID() string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
key, err := parseVertexServiceAccountKey(a)
|
||||
if err == nil {
|
||||
return strings.TrimSpace(key.ProjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) VertexLocation(model string) string {
|
||||
if a == nil {
|
||||
return vertexDefaultLocation
|
||||
}
|
||||
if model != "" && a.Credentials != nil {
|
||||
if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok {
|
||||
if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" {
|
||||
return strings.TrimSpace(loc)
|
||||
}
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("location")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" {
|
||||
return v
|
||||
}
|
||||
return vertexDefaultLocation
|
||||
}
|
||||
|
||||
func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
|
||||
if account == nil || account.Credentials == nil {
|
||||
return nil, errors.New("service account credentials not configured")
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" {
|
||||
return parseVertexServiceAccountJSON([]byte(raw))
|
||||
}
|
||||
if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" {
|
||||
return parseVertexServiceAccountJSON([]byte(raw))
|
||||
}
|
||||
if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok {
|
||||
b, _ := json.Marshal(nested)
|
||||
return parseVertexServiceAccountJSON(b)
|
||||
}
|
||||
if nested, ok := account.Credentials["service_account"].(map[string]any); ok {
|
||||
b, _ := json.Marshal(nested)
|
||||
return parseVertexServiceAccountJSON(b)
|
||||
}
|
||||
return nil, errors.New("service_account_json not found in credentials")
|
||||
}
|
||||
|
||||
func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
|
||||
var key vertexServiceAccountKey
|
||||
if err := json.Unmarshal(raw, &key); err != nil {
|
||||
return nil, fmt.Errorf("invalid service account json: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(key.ClientEmail) == "" {
|
||||
return nil, errors.New("service account json missing client_email")
|
||||
}
|
||||
if strings.TrimSpace(key.PrivateKey) == "" {
|
||||
return nil, errors.New("service account json missing private_key")
|
||||
}
|
||||
if strings.TrimSpace(key.ProjectID) == "" {
|
||||
return nil, errors.New("service account json missing project_id")
|
||||
}
|
||||
// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
|
||||
key.TokenURI = vertexDefaultTokenURL
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
|
||||
fingerprint := ""
|
||||
if key != nil {
|
||||
sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
|
||||
fingerprint = hex.EncodeToString(sum[:8])
|
||||
}
|
||||
if fingerprint == "" && account != nil {
|
||||
fingerprint = fmt.Sprintf("account:%d", account.ID)
|
||||
}
|
||||
return "vertex:service_account:" + fingerprint
|
||||
}
|
||||
|
||||
// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
|
||||
// using the shared cache and distributed lock to avoid redundant exchanges.
|
||||
func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
|
||||
key, err := parseVertexServiceAccountKey(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cacheKey := vertexServiceAccountCacheKey(account, key)
|
||||
|
||||
if cache != nil {
|
||||
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
locked := false
|
||||
if cache != nil {
|
||||
var lockErr error
|
||||
locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(vertexLockWaitTime)
|
||||
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if cache != nil {
|
||||
_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": key.ClientEmail,
|
||||
"scope": vertexCloudPlatformScope,
|
||||
"aud": key.TokenURI,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(time.Hour).Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
if strings.TrimSpace(key.PrivateKeyID) != "" {
|
||||
token.Header["kid"] = key.PrivateKeyID
|
||||
}
|
||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse service account private key: %w", err)
|
||||
}
|
||||
assertion, err := token.SignedString(privateKey)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("sign service account assertion: %w", err)
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
values.Set("assertion", assertion)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("service account token request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var parsed vertexTokenResponse
|
||||
_ = json.Unmarshal(body, &parsed)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := strings.TrimSpace(parsed.ErrorDesc)
|
||||
if msg == "" {
|
||||
msg = strings.TrimSpace(parsed.Error)
|
||||
}
|
||||
if msg == "" {
|
||||
msg = string(bytes.TrimSpace(body))
|
||||
}
|
||||
return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
|
||||
}
|
||||
if strings.TrimSpace(parsed.AccessToken) == "" {
|
||||
return "", 0, errors.New("service account token response missing access_token")
|
||||
}
|
||||
ttl := time.Duration(parsed.ExpiresIn) * time.Second
|
||||
if ttl <= 0 {
|
||||
ttl = time.Hour
|
||||
}
|
||||
if ttl > vertexServiceAccountCacheSkew {
|
||||
ttl -= vertexServiceAccountCacheSkew
|
||||
}
|
||||
return parsed.AccessToken, ttl, nil
|
||||
}
|
||||
|
||||
func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
location = strings.TrimSpace(location)
|
||||
model = strings.TrimSpace(model)
|
||||
action = strings.TrimSpace(action)
|
||||
if projectID == "" {
|
||||
return "", errors.New("vertex project_id is required")
|
||||
}
|
||||
if location == "" {
|
||||
location = vertexDefaultLocation
|
||||
}
|
||||
if !vertexLocationPattern.MatchString(location) {
|
||||
return "", fmt.Errorf("invalid vertex location: %s", location)
|
||||
}
|
||||
if model == "" {
|
||||
return "", errors.New("vertex model is required")
|
||||
}
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent", "countTokens":
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
|
||||
}
|
||||
host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
|
||||
if location == "global" {
|
||||
host = "aiplatform.googleapis.com"
|
||||
}
|
||||
u := fmt.Sprintf(
|
||||
"https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
host,
|
||||
url.PathEscape(projectID),
|
||||
url.PathEscape(location),
|
||||
url.PathEscape(model),
|
||||
action,
|
||||
)
|
||||
if stream {
|
||||
u += "?alt=sse"
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
location = strings.TrimSpace(location)
|
||||
model = strings.TrimSpace(model)
|
||||
if projectID == "" {
|
||||
return "", errors.New("vertex project_id is required")
|
||||
}
|
||||
if location == "" {
|
||||
location = vertexDefaultLocation
|
||||
}
|
||||
if !vertexLocationPattern.MatchString(location) {
|
||||
return "", fmt.Errorf("invalid vertex location: %s", location)
|
||||
}
|
||||
if model == "" {
|
||||
return "", errors.New("vertex model is required")
|
||||
}
|
||||
action := "rawPredict"
|
||||
if stream {
|
||||
action = "streamRawPredict"
|
||||
}
|
||||
host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
|
||||
if location == "global" {
|
||||
host = "aiplatform.googleapis.com"
|
||||
}
|
||||
escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
|
||||
return fmt.Sprintf(
|
||||
"https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
host,
|
||||
url.PathEscape(projectID),
|
||||
url.PathEscape(location),
|
||||
escapedModel,
|
||||
action,
|
||||
), nil
|
||||
}
|
||||
|
||||
func normalizeVertexAnthropicModelID(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) {
|
||||
return model
|
||||
}
|
||||
if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 {
|
||||
return m[1] + "@" + m[2]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
|
||||
}
|
||||
delete(payload, "model")
|
||||
payload["anthropic_version"] = vertexAnthropicVersion
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
77
backend/internal/service/vertex_service_account_test.go
Normal file
77
backend/internal/service/vertex_service_account_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildVertexGeminiURL(t *testing.T) {
|
||||
got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got)
|
||||
}
|
||||
|
||||
func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
|
||||
got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got)
|
||||
}
|
||||
|
||||
func TestBuildVertexAnthropicURL(t *testing.T) {
|
||||
got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got)
|
||||
}
|
||||
|
||||
func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
|
||||
got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got)
|
||||
}
|
||||
|
||||
func TestNormalizeVertexAnthropicModelID(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929"))
|
||||
require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929"))
|
||||
require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6"))
|
||||
}
|
||||
|
||||
func TestBuildVertexAnthropicRequestBody(t *testing.T) {
|
||||
got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", gjson.GetBytes(got, "model").String())
|
||||
require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
|
||||
require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int())
|
||||
require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String())
|
||||
}
|
||||
|
||||
func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
|
||||
_, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid vertex location")
|
||||
}
|
||||
|
||||
func TestParseVertexServiceAccountKey(t *testing.T) {
|
||||
raw := `{
|
||||
"type": "service_account",
|
||||
"project_id": "vertex-proj",
|
||||
"private_key_id": "kid",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "svc@vertex-proj.iam.gserviceaccount.com"
|
||||
}`
|
||||
account := &Account{
|
||||
Type: AccountTypeServiceAccount,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"service_account_json": raw,
|
||||
},
|
||||
}
|
||||
key, err := parseVertexServiceAccountKey(account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "vertex-proj", key.ProjectID)
|
||||
require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail)
|
||||
require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
|
||||
require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
|
||||
}
|
||||
@ -407,12 +407,28 @@ func ProvideBillingCacheService(
|
||||
return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
|
||||
}
|
||||
|
||||
// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation.
|
||||
func ProvideAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
billingCacheService *BillingCacheService,
|
||||
) *APIKeyService {
|
||||
svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg)
|
||||
svc.SetRateLimitCacheInvalidator(billingCacheService)
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
NewAuthService,
|
||||
NewUserService,
|
||||
NewAPIKeyService,
|
||||
ProvideAPIKeyService,
|
||||
ProvideAPIKeyAuthCacheInvalidator,
|
||||
NewGroupService,
|
||||
NewAccountService,
|
||||
|
||||
17
backend/migrations/133_affiliate_rebate_freeze.sql
Normal file
17
backend/migrations/133_affiliate_rebate_freeze.sql
Normal file
@ -0,0 +1,17 @@
|
||||
-- 1) Add frozen quota column to user_affiliates for rebate freeze period.
|
||||
ALTER TABLE user_affiliates
|
||||
ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0;
|
||||
|
||||
COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)';
|
||||
|
||||
-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking.
|
||||
-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp.
|
||||
ALTER TABLE user_affiliate_ledger
|
||||
ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL;
|
||||
|
||||
COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen';
|
||||
|
||||
-- 3) Partial index for efficient thaw queries (only rows still frozen).
|
||||
CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw
|
||||
ON user_affiliate_ledger (user_id, frozen_until)
|
||||
WHERE frozen_until IS NOT NULL;
|
||||
@ -74,6 +74,26 @@ describe('oauth adoption auth api', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('posts affiliate code when completing linuxdo oauth registration', async () => {
|
||||
const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
|
||||
|
||||
await completeLinuxDoOAuthRegistration(
|
||||
'invite-code',
|
||||
{
|
||||
adoptDisplayName: true,
|
||||
adoptAvatar: false
|
||||
},
|
||||
' AFF123 '
|
||||
)
|
||||
|
||||
expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
|
||||
invitation_code: 'invite-code',
|
||||
aff_code: 'AFF123',
|
||||
adopt_display_name: true,
|
||||
adopt_avatar: false
|
||||
})
|
||||
})
|
||||
|
||||
it('posts oidc invitation completion with adoption decisions', async () => {
|
||||
const { completeOIDCOAuthRegistration } = await import('@/api/auth')
|
||||
|
||||
@ -134,6 +154,26 @@ describe('oauth adoption auth api', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('posts affiliate code when creating pending wechat oauth account', async () => {
|
||||
const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
|
||||
|
||||
await createPendingWeChatOAuthAccount(
|
||||
'invite-code',
|
||||
{
|
||||
adoptDisplayName: false,
|
||||
adoptAvatar: true
|
||||
},
|
||||
'WXAFF'
|
||||
)
|
||||
|
||||
expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
|
||||
invitation_code: 'invite-code',
|
||||
aff_code: 'WXAFF',
|
||||
adopt_display_name: false,
|
||||
adopt_avatar: true
|
||||
})
|
||||
})
|
||||
|
||||
it('classifies oauth completion results as login or bind', async () => {
|
||||
const { getOAuthCompletionKind } = await import('@/api/auth')
|
||||
|
||||
|
||||
@ -370,8 +370,8 @@ export async function batchUpdateCredentials(request: {
|
||||
* @returns Success confirmation
|
||||
*/
|
||||
export async function bulkUpdate(
|
||||
accountIds: number[],
|
||||
updates: Record<string, unknown>
|
||||
accountIdsOrPayload: number[] | Record<string, unknown>,
|
||||
updates?: Record<string, unknown>
|
||||
): Promise<{
|
||||
success: number
|
||||
failed: number
|
||||
@ -379,16 +379,19 @@ export async function bulkUpdate(
|
||||
failed_ids?: number[]
|
||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||
}> {
|
||||
const payload = Array.isArray(accountIdsOrPayload)
|
||||
? {
|
||||
account_ids: accountIdsOrPayload,
|
||||
...(updates ?? {})
|
||||
}
|
||||
: accountIdsOrPayload
|
||||
const { data } = await apiClient.post<{
|
||||
success: number
|
||||
failed: number
|
||||
success_ids?: number[]
|
||||
failed_ids?: number[]
|
||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||
}>('/admin/accounts/bulk-update', {
|
||||
account_ids: accountIds,
|
||||
...updates
|
||||
})
|
||||
}>('/admin/accounts/bulk-update', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
|
||||
@ -309,6 +309,9 @@ export interface SystemSettings {
|
||||
// Default settings
|
||||
default_balance: number;
|
||||
affiliate_rebate_rate: number;
|
||||
affiliate_rebate_freeze_hours: number;
|
||||
affiliate_rebate_duration_days: number;
|
||||
affiliate_rebate_per_invitee_cap: number;
|
||||
default_concurrency: number;
|
||||
default_user_rpm_limit: number;
|
||||
default_subscriptions: DefaultSubscriptionSetting[];
|
||||
@ -437,6 +440,7 @@ export interface SystemSettings {
|
||||
enable_fingerprint_unification: boolean;
|
||||
enable_metadata_passthrough: boolean;
|
||||
enable_cch_signing: boolean;
|
||||
enable_anthropic_cache_ttl_1h_injection: boolean;
|
||||
web_search_emulation_enabled?: boolean;
|
||||
|
||||
// Payment configuration
|
||||
@ -482,6 +486,9 @@ export interface SystemSettings {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
affiliate_enabled: boolean;
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
openai_fast_policy_settings?: OpenAIFastPolicySettings;
|
||||
}
|
||||
|
||||
export interface UpdateSettingsRequest {
|
||||
@ -495,6 +502,9 @@ export interface UpdateSettingsRequest {
|
||||
totp_enabled?: boolean; // TOTP 双因素认证
|
||||
default_balance?: number;
|
||||
affiliate_rebate_rate?: number;
|
||||
affiliate_rebate_freeze_hours?: number;
|
||||
affiliate_rebate_duration_days?: number;
|
||||
affiliate_rebate_per_invitee_cap?: number;
|
||||
default_concurrency?: number;
|
||||
default_user_rpm_limit?: number;
|
||||
default_subscriptions?: DefaultSubscriptionSetting[];
|
||||
@ -602,6 +612,7 @@ export interface UpdateSettingsRequest {
|
||||
enable_fingerprint_unification?: boolean;
|
||||
enable_metadata_passthrough?: boolean;
|
||||
enable_cch_signing?: boolean;
|
||||
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
||||
// Payment configuration
|
||||
payment_enabled?: boolean;
|
||||
payment_min_amount?: number;
|
||||
@ -644,6 +655,9 @@ export interface UpdateSettingsRequest {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
affiliate_enabled?: boolean;
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
openai_fast_policy_settings?: OpenAIFastPolicySettings;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -871,6 +885,29 @@ export async function updateRectifierSettings(
|
||||
return data;
|
||||
}
|
||||
|
||||
// ==================== OpenAI Fast Policy Settings ====================
|
||||
|
||||
/**
|
||||
* OpenAI fast/flex policy rule interface.
|
||||
* Matches backend dto.OpenAIFastPolicyRule.
|
||||
*/
|
||||
export interface OpenAIFastPolicyRule {
|
||||
service_tier: "all" | "priority" | "flex";
|
||||
action: "pass" | "filter" | "block";
|
||||
scope: "all" | "oauth" | "apikey" | "bedrock";
|
||||
error_message?: string;
|
||||
model_whitelist?: string[];
|
||||
fallback_action?: "pass" | "filter" | "block";
|
||||
fallback_error_message?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI fast/flex policy settings interface.
|
||||
*/
|
||||
export interface OpenAIFastPolicySettings {
|
||||
rules: OpenAIFastPolicyRule[];
|
||||
}
|
||||
|
||||
// ==================== Beta Policy Settings ====================
|
||||
|
||||
/**
|
||||
|
||||
@ -564,9 +564,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise<Rese
|
||||
*/
|
||||
export async function completeLinuxDoOAuthRegistration(
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<OAuthTokenResponse> {
|
||||
return createPendingLinuxDoOAuthAccount(invitationCode, decision)
|
||||
return createPendingLinuxDoOAuthAccount(invitationCode, decision, affiliateCode)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -576,27 +577,32 @@ export async function completeLinuxDoOAuthRegistration(
|
||||
*/
|
||||
export async function completeOIDCOAuthRegistration(
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<OAuthTokenResponse> {
|
||||
return createPendingOIDCOAuthAccount(invitationCode, decision)
|
||||
return createPendingOIDCOAuthAccount(invitationCode, decision, affiliateCode)
|
||||
}
|
||||
|
||||
export async function completeWeChatOAuthRegistration(
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<OAuthTokenResponse> {
|
||||
return createPendingWeChatOAuthAccount(invitationCode, decision)
|
||||
return createPendingWeChatOAuthAccount(invitationCode, decision, affiliateCode)
|
||||
}
|
||||
|
||||
async function createPendingOAuthAccount(
|
||||
provider: 'linuxdo' | 'oidc' | 'wechat',
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<PendingOAuthCreateAccountResponse> {
|
||||
const normalizedAffiliateCode = affiliateCode?.trim()
|
||||
const { data } = await apiClient.post<PendingOAuthCreateAccountResponse>(
|
||||
`/auth/oauth/${provider}/complete-registration`,
|
||||
{
|
||||
invitation_code: invitationCode,
|
||||
...(normalizedAffiliateCode ? { aff_code: normalizedAffiliateCode } : {}),
|
||||
...serializeOAuthAdoptionDecision(decision)
|
||||
}
|
||||
)
|
||||
@ -605,23 +611,26 @@ async function createPendingOAuthAccount(
|
||||
|
||||
export async function createPendingLinuxDoOAuthAccount(
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<PendingOAuthCreateAccountResponse> {
|
||||
return createPendingOAuthAccount('linuxdo', invitationCode, decision)
|
||||
return createPendingOAuthAccount('linuxdo', invitationCode, decision, affiliateCode)
|
||||
}
|
||||
|
||||
export async function createPendingOIDCOAuthAccount(
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<PendingOAuthCreateAccountResponse> {
|
||||
return createPendingOAuthAccount('oidc', invitationCode, decision)
|
||||
return createPendingOAuthAccount('oidc', invitationCode, decision, affiliateCode)
|
||||
}
|
||||
|
||||
export async function createPendingWeChatOAuthAccount(
|
||||
invitationCode: string,
|
||||
decision?: OAuthAdoptionDecision
|
||||
decision?: OAuthAdoptionDecision,
|
||||
affiliateCode?: string
|
||||
): Promise<PendingOAuthCreateAccountResponse> {
|
||||
return createPendingOAuthAccount('wechat', invitationCode, decision)
|
||||
return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode)
|
||||
}
|
||||
|
||||
export async function completePendingOAuthBindLogin(
|
||||
|
||||
@ -332,6 +332,37 @@
|
||||
|
||||
<!-- Usage data or unlimited flow -->
|
||||
<div class="space-y-1">
|
||||
<div
|
||||
v-if="showGeminiTodayStats && todayStats"
|
||||
class="mb-0.5 flex items-center"
|
||||
>
|
||||
<div class="flex items-center gap-1.5 text-[9px] text-gray-500 dark:text-gray-400">
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatKeyRequests }} req
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatKeyTokens }}
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800" :title="t('usage.accountBilled')">
|
||||
A ${{ formatKeyCost }}
|
||||
</span>
|
||||
<span
|
||||
v-if="todayStats.user_cost != null"
|
||||
class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"
|
||||
:title="t('usage.userBilled')"
|
||||
>
|
||||
U ${{ formatKeyUserCost }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-else-if="showGeminiTodayStats && todayStatsLoading"
|
||||
class="mb-0.5 flex items-center gap-1"
|
||||
>
|
||||
<div class="h-3 w-10 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-3 w-8 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-3 w-12 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
<div v-if="loading" class="space-y-1">
|
||||
<div class="flex items-center gap-1">
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
@ -545,6 +576,10 @@ const shouldFetchUsage = computed(() => {
|
||||
return false
|
||||
})
|
||||
|
||||
const showGeminiTodayStats = computed(() => {
|
||||
return props.account.platform === 'gemini' && props.account.type === 'service_account'
|
||||
})
|
||||
|
||||
const geminiUsageAvailable = computed(() => {
|
||||
return (
|
||||
!!usageInfo.value?.gemini_shared_daily ||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
|
||||
{{ t('admin.accounts.bulkEdit.selectionInfo', { count: targetMode === 'filtered' ? targetPreviewCount : accountIds.length }) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
<svg class="mr-1.5 inline h-5 w-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||
</svg>
|
||||
{{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
|
||||
{{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: targetSelectedPlatforms.join(', ') }) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@ -227,7 +227,7 @@
|
||||
|
||||
<ModelWhitelistSelector
|
||||
v-model="allowedModels"
|
||||
:platforms="selectedPlatforms"
|
||||
:platforms="targetSelectedPlatforms"
|
||||
/>
|
||||
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
@ -698,6 +698,87 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI OAuth Codex CLI only -->
|
||||
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<label
|
||||
id="bulk-edit-openai-codex-cli-only-label"
|
||||
class="input-label mb-0"
|
||||
for="bulk-edit-openai-codex-cli-only-enabled"
|
||||
>
|
||||
{{ t('admin.accounts.openai.codexCLIOnly') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="enableCodexCLIOnly"
|
||||
id="bulk-edit-openai-codex-cli-only-enabled"
|
||||
type="checkbox"
|
||||
aria-controls="bulk-edit-openai-codex-cli-only"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
id="bulk-edit-openai-codex-cli-only"
|
||||
:class="!enableCodexCLIOnly && 'pointer-events-none opacity-50'"
|
||||
>
|
||||
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.codexCLIOnlyDesc') }}
|
||||
</p>
|
||||
<button
|
||||
id="bulk-edit-openai-codex-cli-only-toggle"
|
||||
type="button"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
codexCLIOnlyEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
@click="codexCLIOnlyEnabled = !codexCLIOnlyEnabled"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
codexCLIOnlyEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI API Key WS mode -->
|
||||
<div v-if="allOpenAIAPIKey" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<label
|
||||
id="bulk-edit-openai-apikey-ws-mode-label"
|
||||
class="input-label mb-0"
|
||||
for="bulk-edit-openai-apikey-ws-mode-enabled"
|
||||
>
|
||||
{{ t('admin.accounts.openai.wsMode') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="enableOpenAIAPIKeyWSMode"
|
||||
id="bulk-edit-openai-apikey-ws-mode-enabled"
|
||||
type="checkbox"
|
||||
aria-controls="bulk-edit-openai-apikey-ws-mode"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
id="bulk-edit-openai-apikey-ws-mode"
|
||||
:class="!enableOpenAIAPIKeyWSMode && 'pointer-events-none opacity-50'"
|
||||
>
|
||||
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.wsModeDesc') }}
|
||||
</p>
|
||||
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t(openAIAPIKeyWSModeConcurrencyHintKey) }}
|
||||
</p>
|
||||
<Select
|
||||
v-model="openaiAPIKeyResponsesWebSocketV2Mode"
|
||||
data-testid="bulk-edit-openai-apikey-ws-mode-select"
|
||||
:options="openAIWSModeOptions"
|
||||
aria-labelledby="bulk-edit-openai-apikey-ws-mode-label"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
|
||||
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@ -933,6 +1014,13 @@ interface Props {
|
||||
accountIds: number[]
|
||||
selectedPlatforms: AccountPlatform[]
|
||||
selectedTypes: AccountType[]
|
||||
target?: {
|
||||
mode: 'selected' | 'filtered'
|
||||
filters?: Record<string, unknown>
|
||||
previewCount?: number
|
||||
selectedPlatforms?: AccountPlatform[]
|
||||
selectedTypes?: AccountType[]
|
||||
}
|
||||
proxies: ProxyConfig[]
|
||||
groups: AdminGroup[]
|
||||
}
|
||||
@ -947,40 +1035,53 @@ const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
// Platform awareness
|
||||
const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
|
||||
const targetMode = computed(() => props.target?.mode ?? 'selected')
|
||||
const targetPreviewCount = computed(() => props.target?.previewCount ?? props.accountIds.length)
|
||||
const targetSelectedPlatforms = computed(() => props.target?.selectedPlatforms ?? props.selectedPlatforms)
|
||||
const targetSelectedTypes = computed(() => props.target?.selectedTypes ?? props.selectedTypes)
|
||||
const isMixedPlatform = computed(() => targetSelectedPlatforms.value.length > 1)
|
||||
|
||||
const allOpenAIPassthroughCapable = computed(() => {
|
||||
return (
|
||||
props.selectedPlatforms.length === 1 &&
|
||||
props.selectedPlatforms[0] === 'openai' &&
|
||||
props.selectedTypes.length > 0 &&
|
||||
props.selectedTypes.every(t => t === 'oauth' || t === 'apikey')
|
||||
targetSelectedPlatforms.value.length === 1 &&
|
||||
targetSelectedPlatforms.value[0] === 'openai' &&
|
||||
targetSelectedTypes.value.length > 0 &&
|
||||
targetSelectedTypes.value.every(t => t === 'oauth' || t === 'apikey')
|
||||
)
|
||||
})
|
||||
|
||||
const allOpenAIOAuth = computed(() => {
|
||||
return (
|
||||
props.selectedPlatforms.length === 1 &&
|
||||
props.selectedPlatforms[0] === 'openai' &&
|
||||
props.selectedTypes.length > 0 &&
|
||||
props.selectedTypes.every(t => t === 'oauth')
|
||||
targetSelectedPlatforms.value.length === 1 &&
|
||||
targetSelectedPlatforms.value[0] === 'openai' &&
|
||||
targetSelectedTypes.value.length > 0 &&
|
||||
targetSelectedTypes.value.every(t => t === 'oauth')
|
||||
)
|
||||
})
|
||||
|
||||
const allOpenAIAPIKey = computed(() => {
|
||||
return (
|
||||
targetSelectedPlatforms.value.length === 1 &&
|
||||
targetSelectedPlatforms.value[0] === 'openai' &&
|
||||
targetSelectedTypes.value.length > 0 &&
|
||||
targetSelectedTypes.value.every(t => t === 'apikey')
|
||||
)
|
||||
})
|
||||
|
||||
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
|
||||
const allAnthropicOAuthOrSetupToken = computed(() => {
|
||||
return (
|
||||
props.selectedPlatforms.length === 1 &&
|
||||
props.selectedPlatforms[0] === 'anthropic' &&
|
||||
props.selectedTypes.every(t => t === 'oauth' || t === 'setup-token')
|
||||
targetSelectedPlatforms.value.length === 1 &&
|
||||
targetSelectedPlatforms.value[0] === 'anthropic' &&
|
||||
targetSelectedTypes.value.every(t => t === 'oauth' || t === 'setup-token')
|
||||
)
|
||||
})
|
||||
|
||||
const filteredPresets = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return []
|
||||
if (targetSelectedPlatforms.value.length === 0) return []
|
||||
|
||||
const dedupedPresets = new Map<string, ReturnType<typeof getPresetMappingsByPlatform>[number]>()
|
||||
for (const platform of props.selectedPlatforms) {
|
||||
for (const platform of targetSelectedPlatforms.value) {
|
||||
for (const preset of getPresetMappingsByPlatform(platform)) {
|
||||
const key = `${preset.from}=>${preset.to}`
|
||||
if (!dedupedPresets.has(key)) {
|
||||
@ -1012,6 +1113,8 @@ const enableStatus = ref(false)
|
||||
const enableGroups = ref(false)
|
||||
const enableOpenAIPassthrough = ref(false)
|
||||
const enableOpenAIWSMode = ref(false)
|
||||
const enableOpenAIAPIKeyWSMode = ref(false)
|
||||
const enableCodexCLIOnly = ref(false)
|
||||
const enableRpmLimit = ref(false)
|
||||
|
||||
// State - field values
|
||||
@ -1035,6 +1138,8 @@ const status = ref<'active' | 'inactive'>('active')
|
||||
const groupIds = ref<number[]>([])
|
||||
const openaiPassthroughEnabled = ref(false)
|
||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||
const codexCLIOnlyEnabled = ref(false)
|
||||
const rpmLimitEnabled = ref(false)
|
||||
const bulkBaseRpm = ref<number | null>(null)
|
||||
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
|
||||
@ -1076,6 +1181,9 @@ const openAIWSModeOptions = computed(() => [
|
||||
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
)
|
||||
const openAIAPIKeyWSModeConcurrencyHintKey = computed(() =>
|
||||
resolveOpenAIWSModeConcurrencyHintKey(openaiAPIKeyResponsesWebSocketV2Mode.value)
|
||||
)
|
||||
|
||||
// Model mapping helpers
|
||||
const addModelMapping = () => {
|
||||
@ -1254,6 +1362,19 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
)
|
||||
}
|
||||
|
||||
if (enableOpenAIAPIKeyWSMode.value) {
|
||||
const extra = ensureExtra()
|
||||
extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||
extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
|
||||
openaiAPIKeyResponsesWebSocketV2Mode.value
|
||||
)
|
||||
}
|
||||
|
||||
if (enableCodexCLIOnly.value) {
|
||||
const extra = ensureExtra()
|
||||
extra.codex_cli_only = codexCLIOnlyEnabled.value
|
||||
}
|
||||
|
||||
// RPM limit settings (写入 extra 字段)
|
||||
if (enableRpmLimit.value) {
|
||||
const extra = ensureExtra()
|
||||
@ -1291,8 +1412,8 @@ const mixedChannelConfirmed = ref(false)
|
||||
const canPreCheck = () =>
|
||||
enableGroups.value &&
|
||||
groupIds.value.length > 0 &&
|
||||
props.selectedPlatforms.length === 1 &&
|
||||
(props.selectedPlatforms[0] === 'antigravity' || props.selectedPlatforms[0] === 'anthropic')
|
||||
targetSelectedPlatforms.value.length === 1 &&
|
||||
(targetSelectedPlatforms.value[0] === 'antigravity' || targetSelectedPlatforms.value[0] === 'anthropic')
|
||||
|
||||
const handleClose = () => {
|
||||
showMixedChannelWarning.value = false
|
||||
@ -1309,7 +1430,7 @@ const preCheckMixedChannelRisk = async (built: Record<string, unknown>): Promise
|
||||
|
||||
try {
|
||||
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
||||
platform: props.selectedPlatforms[0],
|
||||
platform: targetSelectedPlatforms.value[0],
|
||||
group_ids: groupIds.value
|
||||
})
|
||||
if (!result.has_risk) return true
|
||||
@ -1325,7 +1446,7 @@ const preCheckMixedChannelRisk = async (built: Record<string, unknown>): Promise
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (props.accountIds.length === 0) {
|
||||
if (targetMode.value === 'selected' && props.accountIds.length === 0) {
|
||||
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
|
||||
return
|
||||
}
|
||||
@ -1344,6 +1465,8 @@ const handleSubmit = async () => {
|
||||
enableStatus.value ||
|
||||
enableGroups.value ||
|
||||
enableOpenAIWSMode.value ||
|
||||
enableOpenAIAPIKeyWSMode.value ||
|
||||
enableCodexCLIOnly.value ||
|
||||
enableRpmLimit.value ||
|
||||
userMsgQueueMode.value !== null
|
||||
|
||||
@ -1373,7 +1496,12 @@ const submitBulkUpdate = async (baseUpdates: Record<string, unknown>) => {
|
||||
submitting.value = true
|
||||
|
||||
try {
|
||||
const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
|
||||
const res = targetMode.value === 'filtered' && props.target?.filters
|
||||
? await adminAPI.accounts.bulkUpdate({
|
||||
filters: props.target.filters,
|
||||
...updates
|
||||
})
|
||||
: await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
|
||||
const success = res.success || 0
|
||||
const failed = res.failed || 0
|
||||
|
||||
@ -1437,6 +1565,8 @@ watch(
|
||||
enableGroups.value = false
|
||||
enableOpenAIPassthrough.value = false
|
||||
enableOpenAIWSMode.value = false
|
||||
enableOpenAIAPIKeyWSMode.value = false
|
||||
enableCodexCLIOnly.value = false
|
||||
enableRpmLimit.value = false
|
||||
|
||||
// Reset all values
|
||||
@ -1456,6 +1586,8 @@ watch(
|
||||
status.value = 'active'
|
||||
groupIds.value = []
|
||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||
codexCLIOnlyEnabled.value = false
|
||||
rpmLimitEnabled.value = false
|
||||
bulkBaseRpm.value = null
|
||||
bulkRpmStrategy.value = 'tiered'
|
||||
|
||||
@ -166,7 +166,7 @@
|
||||
<!-- Account Type Selection (Anthropic) -->
|
||||
<div v-if="form.platform === 'anthropic'">
|
||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||
<div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
|
||||
<div class="mt-2 grid grid-cols-2 gap-3 sm:grid-cols-4" data-tour="account-form-type">
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'oauth-based'"
|
||||
@ -257,6 +257,39 @@
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'service_account'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'service_account'
|
||||
? 'border-sky-500 bg-sky-50 dark:bg-sky-900/20'
|
||||
: 'border-gray-200 hover:border-sky-300 dark:border-dark-600 dark:hover:border-sky-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
accountCategory === 'service_account'
|
||||
? 'bg-sky-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="cloud" size="sm" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">Vertex</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">Service Account</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="accountCategory === 'service_account'"
|
||||
class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200"
|
||||
>
|
||||
<p>{{ t('admin.accounts.vertexAnthropicHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -315,6 +348,7 @@
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.responsesApi') }}</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -333,7 +367,7 @@
|
||||
{{ t('admin.accounts.gemini.helpButton') }}
|
||||
</button>
|
||||
</div>
|
||||
<div class="mt-2 grid grid-cols-2 gap-3" data-tour="account-form-type">
|
||||
<div class="mt-2 grid grid-cols-3 gap-3" data-tour="account-form-type">
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'oauth-based'"
|
||||
@ -405,6 +439,36 @@
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'service_account'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'service_account'
|
||||
? 'border-sky-500 bg-sky-50 dark:bg-sky-900/20'
|
||||
: 'border-gray-200 hover:border-sky-300 dark:border-dark-600 dark:hover:border-sky-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
accountCategory === 'service_account'
|
||||
? 'bg-sky-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="cloud" size="sm" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
||||
Vertex
|
||||
</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
||||
Service Account
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
@ -424,6 +488,13 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="accountCategory === 'service_account'"
|
||||
class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200"
|
||||
>
|
||||
<p>{{ t('admin.accounts.vertexGeminiHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- OAuth Type Selection (only show when oauth-based is selected) -->
|
||||
<div v-if="accountCategory === 'oauth-based'" class="mt-4">
|
||||
<label class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</label>
|
||||
@ -623,7 +694,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Tier selection (used as fallback when auto-detection is unavailable/fails) -->
|
||||
<div class="mt-4">
|
||||
<div v-if="accountCategory !== 'service_account'" class="mt-4">
|
||||
<label class="input-label">{{ t('admin.accounts.gemini.tier.label') }}</label>
|
||||
<div class="mt-2">
|
||||
<select
|
||||
@ -872,6 +943,96 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Vertex Service Account -->
|
||||
<div v-if="(form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory === 'service_account'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">Service Account JSON</label>
|
||||
<input
|
||||
ref="vertexServiceAccountFileInput"
|
||||
type="file"
|
||||
accept="application/json,.json"
|
||||
class="hidden"
|
||||
@change="handleVertexServiceAccountFile"
|
||||
/>
|
||||
<div
|
||||
:class="[
|
||||
'rounded-lg border-2 border-dashed px-4 py-5 transition-colors',
|
||||
vertexServiceAccountDragActive
|
||||
? 'border-sky-500 bg-sky-50 dark:border-sky-500 dark:bg-sky-900/20'
|
||||
: 'border-gray-300 bg-gray-50 hover:border-sky-400 hover:bg-sky-50/60 dark:border-dark-500 dark:bg-dark-700/40 dark:hover:border-sky-600 dark:hover:bg-sky-900/10'
|
||||
]"
|
||||
@dragenter.prevent="vertexServiceAccountDragActive = true"
|
||||
@dragover.prevent="vertexServiceAccountDragActive = true"
|
||||
@dragleave.prevent="vertexServiceAccountDragActive = false"
|
||||
@drop.prevent="handleVertexServiceAccountDrop"
|
||||
>
|
||||
<div class="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
|
||||
<div class="min-w-0">
|
||||
<div class="flex items-center gap-2 text-sm font-medium text-gray-900 dark:text-white">
|
||||
<Icon name="upload" size="sm" />
|
||||
<span>{{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }}</span>
|
||||
</div>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
class="btn btn-secondary shrink-0"
|
||||
@click="vertexServiceAccountFileInput?.click()"
|
||||
>
|
||||
<Icon name="upload" size="sm" />
|
||||
{{ t('admin.accounts.vertexSaJsonSelectBtn') }}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
v-if="vertexClientEmail"
|
||||
class="mt-3 rounded-md border border-sky-200 bg-white px-3 py-2 text-xs text-sky-900 dark:border-sky-800/50 dark:bg-dark-800 dark:text-sky-200"
|
||||
>
|
||||
<div class="truncate">Project ID: <span class="font-mono">{{ vertexProjectId }}</span></div>
|
||||
<div class="truncate">Client Email: <span class="font-mono">{{ vertexClientEmail }}</span></div>
|
||||
</div>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.vertexSaJsonUploadHint') }}</p>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
|
||||
<div>
|
||||
<label class="input-label">Project ID</label>
|
||||
<input
|
||||
v-model="vertexProjectId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
readonly
|
||||
:placeholder="t('admin.accounts.vertexProjectIdPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">Location</label>
|
||||
<select
|
||||
v-model="vertexLocation"
|
||||
required
|
||||
class="input font-mono"
|
||||
>
|
||||
<optgroup
|
||||
v-for="group in VERTEX_LOCATION_OPTIONS"
|
||||
:key="group.label"
|
||||
:label="group.label"
|
||||
>
|
||||
<option
|
||||
v-for="option in group.options"
|
||||
:key="option.value"
|
||||
:value="option.value"
|
||||
>
|
||||
{{ option.label }}
|
||||
</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.vertexLocationHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Antigravity model restriction (applies to OAuth + Upstream) -->
|
||||
<!-- Antigravity 只支持模型映射模式,不支持白名单模式 -->
|
||||
<div v-if="form.platform === 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
@ -3119,6 +3280,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
||||
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
|
||||
import {
|
||||
OPENAI_WS_MODE_CTX_POOL,
|
||||
OPENAI_WS_MODE_OFF,
|
||||
@ -3233,7 +3395,7 @@ interface TempUnschedRuleForm {
|
||||
// State
|
||||
const step = ref(1)
|
||||
const submitting = ref(false)
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category
|
||||
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
|
||||
const apiKeyBaseUrl = ref('https://api.anthropic.com')
|
||||
const apiKeyValue = ref('')
|
||||
@ -3306,6 +3468,12 @@ const bedrockSessionToken = ref('')
|
||||
const bedrockRegion = ref('us-east-1')
|
||||
const bedrockForceGlobal = ref(false)
|
||||
const bedrockApiKeyValue = ref('')
|
||||
const vertexServiceAccountFileInput = ref<HTMLInputElement | null>(null)
|
||||
const vertexServiceAccountJson = ref('')
|
||||
const vertexProjectId = ref('')
|
||||
const vertexClientEmail = ref('')
|
||||
const vertexLocation = ref('global')
|
||||
const vertexServiceAccountDragActive = ref(false)
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||
@ -3556,7 +3724,7 @@ watch(
|
||||
|
||||
// Sync form.type based on accountCategory, addMethod, and platform-specific type
|
||||
watch(
|
||||
[accountCategory, addMethod, antigravityAccountType],
|
||||
[accountCategory, addMethod, antigravityAccountType, () => form.platform],
|
||||
([category, method, agType]) => {
|
||||
// Antigravity upstream 类型(实际创建为 apikey)
|
||||
if (form.platform === 'antigravity' && agType === 'upstream') {
|
||||
@ -3568,7 +3736,9 @@ watch(
|
||||
form.type = 'bedrock' as AccountType
|
||||
return
|
||||
}
|
||||
if (category === 'oauth-based') {
|
||||
if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') {
|
||||
form.type = 'service_account' as AccountType
|
||||
} else if (category === 'oauth-based') {
|
||||
form.type = method as AccountType // 'oauth' or 'setup-token'
|
||||
} else {
|
||||
form.type = 'apikey'
|
||||
@ -3606,6 +3776,12 @@ watch(
|
||||
antigravityModelMappings.value = []
|
||||
antigravityModelRestrictionMode.value = 'mapping'
|
||||
}
|
||||
if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
|
||||
accountCategory.value = 'oauth-based'
|
||||
}
|
||||
if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') {
|
||||
accountCategory.value = 'oauth-based'
|
||||
}
|
||||
// Reset Bedrock fields when switching platforms
|
||||
bedrockAccessKeyId.value = ''
|
||||
bedrockSecretAccessKey.value = ''
|
||||
@ -3614,6 +3790,10 @@ watch(
|
||||
bedrockForceGlobal.value = false
|
||||
bedrockAuthMode.value = 'sigv4'
|
||||
bedrockApiKeyValue.value = ''
|
||||
vertexServiceAccountJson.value = ''
|
||||
vertexProjectId.value = ''
|
||||
vertexClientEmail.value = ''
|
||||
vertexLocation.value = 'global'
|
||||
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
|
||||
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
|
||||
interceptWarmupRequests.value = false
|
||||
@ -4068,6 +4248,10 @@ const resetForm = () => {
|
||||
upstreamType.value = 'sub2api'
|
||||
upstreamBaseUrl.value = ''
|
||||
upstreamApiKey.value = ''
|
||||
vertexServiceAccountJson.value = ''
|
||||
vertexProjectId.value = ''
|
||||
vertexClientEmail.value = ''
|
||||
vertexLocation.value = 'global'
|
||||
tempUnschedEnabled.value = false
|
||||
tempUnschedRules.value = []
|
||||
geminiOAuthType.value = 'code_assist'
|
||||
@ -4195,6 +4379,52 @@ const normalizePoolModeRetryCount = (value: number) => {
|
||||
return normalized
|
||||
}
|
||||
|
||||
const applyVertexServiceAccountJson = (value: string) => {
|
||||
const raw = value.trim()
|
||||
if (!raw) {
|
||||
vertexProjectId.value = ''
|
||||
vertexClientEmail.value = ''
|
||||
return false
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(raw) as Record<string, unknown>
|
||||
const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : ''
|
||||
const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : ''
|
||||
const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : ''
|
||||
if (!projectId || !clientEmail || !privateKey) {
|
||||
appStore.showError(t('admin.accounts.vertexSaJsonMissingFields'))
|
||||
return false
|
||||
}
|
||||
vertexProjectId.value = projectId
|
||||
vertexClientEmail.value = clientEmail
|
||||
vertexServiceAccountJson.value = JSON.stringify(parsed)
|
||||
return true
|
||||
} catch {
|
||||
appStore.showError(t('admin.accounts.vertexSaJsonInvalid'))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const parseVertexServiceAccountJson = () => applyVertexServiceAccountJson(vertexServiceAccountJson.value)
|
||||
|
||||
const handleVertexServiceAccountFile = async (event: Event) => {
|
||||
const input = event.target as HTMLInputElement
|
||||
const file = input.files?.[0]
|
||||
if (!file) return
|
||||
try {
|
||||
applyVertexServiceAccountJson(await file.text())
|
||||
} finally {
|
||||
input.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
const handleVertexServiceAccountDrop = async (event: DragEvent) => {
|
||||
vertexServiceAccountDragActive.value = false
|
||||
const file = event.dataTransfer?.files?.[0]
|
||||
if (!file) return
|
||||
applyVertexServiceAccountJson(await file.text())
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
// For OAuth-based type, handle OAuth flow (goes to step 2)
|
||||
if (isOAuthFlow.value) {
|
||||
@ -4348,6 +4578,29 @@ const handleSubmit = async () => {
|
||||
return
|
||||
}
|
||||
|
||||
if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') {
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!parseVertexServiceAccountJson()) {
|
||||
return
|
||||
}
|
||||
if (!vertexLocation.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.vertexLocationRequired'))
|
||||
return
|
||||
}
|
||||
const credentials: Record<string, unknown> = {
|
||||
service_account_json: vertexServiceAccountJson.value.trim(),
|
||||
project_id: vertexProjectId.value.trim(),
|
||||
client_email: vertexClientEmail.value.trim(),
|
||||
location: vertexLocation.value.trim(),
|
||||
tier_id: 'vertex'
|
||||
}
|
||||
await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For apikey type, create directly
|
||||
if (!apiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
|
||||
|
||||
@ -599,6 +599,221 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Vertex Service Account -->
|
||||
<div v-if="(account.platform === 'gemini' || account.platform === 'anthropic') && account.type === 'service_account'" class="space-y-4">
|
||||
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
|
||||
<div>
|
||||
<label class="input-label">Project ID</label>
|
||||
<input
|
||||
v-model="editVertexProjectId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
readonly
|
||||
:placeholder="t('admin.accounts.vertexProjectIdPlaceholder')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.vertexSaJsonEditHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">Location</label>
|
||||
<select
|
||||
v-model="editVertexLocation"
|
||||
required
|
||||
class="input font-mono"
|
||||
>
|
||||
<optgroup
|
||||
v-for="group in VERTEX_LOCATION_OPTIONS"
|
||||
:key="group.label"
|
||||
:label="group.label"
|
||||
>
|
||||
<option
|
||||
v-for="option in group.options"
|
||||
:key="option.value"
|
||||
:value="option.value"
|
||||
>
|
||||
{{ option.label }}
|
||||
</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.vertexLocationHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section for Service Account -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
<svg
|
||||
class="mr-1.5 inline h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
<svg
|
||||
class="mr-1.5 inline h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M8 7h12m0 0l-4-4m4 4l-4 4m0 6H4m0 0l4 4m-4-4l4-4"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{
|
||||
t('admin.accounts.supportsAllModels')
|
||||
}}</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else>
|
||||
<div class="mb-3 rounded-lg bg-purple-50 p-3 dark:bg-purple-900/20">
|
||||
<p class="text-xs text-purple-700 dark:text-purple-400">
|
||||
<svg
|
||||
class="mr-1 inline h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.mapRequestModels') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Mapping List -->
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="getModelMappingKey(mapping)"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
v-model="mapping.from"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.requestModel')"
|
||||
/>
|
||||
<svg
|
||||
class="h-4 w-4 flex-shrink-0 text-gray-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M14 5l7 7m0 0l-7 7m7-7H3"
|
||||
/>
|
||||
</svg>
|
||||
<input
|
||||
v-model="mapping.to"
|
||||
type="text"
|
||||
class="input flex-1"
|
||||
:placeholder="t('admin.accounts.actualModel')"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeModelMapping(index)"
|
||||
class="rounded-lg p-2 text-red-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
>
|
||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="addModelMapping"
|
||||
class="mb-3 w-full rounded-lg border-2 border-dashed border-gray-300 px-4 py-2 text-gray-600 transition-colors hover:border-gray-400 hover:text-gray-700 dark:border-dark-500 dark:text-gray-400 dark:hover:border-dark-400 dark:hover:text-gray-300"
|
||||
>
|
||||
<svg
|
||||
class="mr-1 inline h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M12 4v16m8-8H4"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
|
||||
<!-- Quick Add Buttons -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in presetMappings"
|
||||
:key="preset.label"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock fields (for bedrock type, both SigV4 and API Key modes) -->
|
||||
<div v-if="account.type === 'bedrock'" class="space-y-4">
|
||||
<!-- SigV4 fields -->
|
||||
@ -1951,6 +2166,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
|
||||
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||
import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
|
||||
import {
|
||||
OPENAI_WS_MODE_CTX_POOL,
|
||||
OPENAI_WS_MODE_OFF,
|
||||
@ -2020,6 +2236,9 @@ const editBedrockSessionToken = ref('')
|
||||
const editBedrockRegion = ref('')
|
||||
const editBedrockForceGlobal = ref(false)
|
||||
const editBedrockApiKeyValue = ref('')
|
||||
const editVertexProjectId = ref('')
|
||||
const editVertexClientEmail = ref('')
|
||||
const editVertexLocation = ref('us-central1')
|
||||
const isBedrockAPIKeyMode = computed(() =>
|
||||
props.account?.type === 'bedrock' &&
|
||||
(props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
|
||||
@ -2279,6 +2498,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
const credentials = newAccount.credentials as Record<string, unknown> | undefined
|
||||
interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
|
||||
autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true
|
||||
editVertexProjectId.value = ''
|
||||
editVertexClientEmail.value = ''
|
||||
editVertexLocation.value = 'us-central1'
|
||||
|
||||
// Load mixed scheduling setting (only for antigravity accounts)
|
||||
mixedScheduling.value = false
|
||||
@ -2504,6 +2726,31 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
||||
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
|
||||
const credentials = newAccount.credentials as Record<string, unknown>
|
||||
editBaseUrl.value = (credentials.base_url as string) || ''
|
||||
} else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) {
|
||||
const credentials = newAccount.credentials as Record<string, unknown>
|
||||
editVertexProjectId.value = (credentials.project_id as string) || ''
|
||||
editVertexClientEmail.value = (credentials.client_email as string) || ''
|
||||
editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1'
|
||||
|
||||
// Load model mappings for service_account
|
||||
const existingMappings = credentials.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
const entries = Object.entries(existingMappings)
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
if (isWhitelistMode) {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = entries.map(([from]) => from)
|
||||
modelMappings.value = []
|
||||
} else {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
const platformDefaultUrl =
|
||||
newAccount.platform === 'openai'
|
||||
@ -3099,6 +3346,46 @@ const handleSubmit = async () => {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
if (!editVertexProjectId.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId'))
|
||||
return
|
||||
}
|
||||
if (!editVertexClientEmail.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail'))
|
||||
return
|
||||
}
|
||||
if (!editVertexLocation.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.vertexLocationRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
if (!currentCredentials.service_account_json && !currentCredentials.service_account) {
|
||||
appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
|
||||
return
|
||||
}
|
||||
newCredentials.project_id = editVertexProjectId.value.trim()
|
||||
newCredentials.client_email = editVertexClientEmail.value.trim()
|
||||
newCredentials.location = editVertexLocation.value.trim()
|
||||
newCredentials.tier_id = 'vertex'
|
||||
|
||||
// Add model mapping if configured
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
} else {
|
||||
delete newCredentials.model_mapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else if (props.account.type === 'bedrock') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
|
||||
@ -57,6 +57,19 @@ function makeAccount(overrides: Partial<Account>): Account {
|
||||
describe('AccountUsageCell', () => {
|
||||
beforeEach(() => {
|
||||
getUsage.mockReset()
|
||||
Object.defineProperty(window, 'matchMedia', {
|
||||
writable: true,
|
||||
value: vi.fn().mockImplementation(() => ({
|
||||
matches: true,
|
||||
media: '(min-width: 768px)',
|
||||
onchange: null,
|
||||
addListener: vi.fn(),
|
||||
removeListener: vi.fn(),
|
||||
addEventListener: vi.fn(),
|
||||
removeEventListener: vi.fn(),
|
||||
dispatchEvent: vi.fn(),
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
it('Antigravity 图片用量会聚合新旧 image 模型', async () => {
|
||||
@ -603,4 +616,43 @@ describe('AccountUsageCell', () => {
|
||||
|
||||
expect(wrapper.text().trim()).toBe('-')
|
||||
})
|
||||
|
||||
it('Vertex 账号会在 Gemini 用量窗口里展示 today stats 徽章', async () => {
|
||||
const wrapper = mount(AccountUsageCell, {
|
||||
props: {
|
||||
account: makeAccount({
|
||||
id: 4001,
|
||||
platform: 'gemini',
|
||||
type: 'service_account',
|
||||
credentials: {
|
||||
tier_id: 'vertex',
|
||||
project_id: 'vertex-proj',
|
||||
client_email: 'svc@vertex-proj.iam.gserviceaccount.com',
|
||||
location: 'global'
|
||||
},
|
||||
extra: {}
|
||||
}),
|
||||
todayStats: {
|
||||
requests: 0,
|
||||
tokens: 0,
|
||||
cost: 0,
|
||||
standard_cost: 0,
|
||||
user_cost: 0
|
||||
}
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
UsageProgressBar: true,
|
||||
AccountQuotaInfo: true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(wrapper.text()).toContain('0 req')
|
||||
expect(wrapper.text()).toContain('0')
|
||||
expect(wrapper.text()).toContain('A $0.00')
|
||||
expect(wrapper.text()).toContain('U $0.00')
|
||||
})
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user