chore: merge upstream v0.1.126 — Airwallex, OpenAI fixes, Antigravity UA config
吸收上游 26 个新 commit: - feat: Airwallex 支付 + 多币种支持 (b23055af) - feat: Antigravity user agent 版本可配置 (a07a0dac) - fix(mimic): 同步 messages 里 tool_use 名称 (f97b8534) - fix: cache_control 改写默认关闭 (9377c967) - fix(openai): 多 tool_use 上下文延续 (87d73236) - fix(openai): 未定价模型零成本记录 (6d69ae87) - fix(openai): WS replay tool 输出延续 (16a31557) - fix(openai): 429 plan type 同步 (c3a14717) - fix(gemini): Vertex token 走 account proxy (2a17c0b2) - fix(ccswitch): codex 模型 import deeplink (65493df9) - fix: 订单详情/支付页 NaN 修复 (ba1c6fa5, 6884b03e) - 系统设置标签导航优化 (18cc4691) 本地解决: - config.go CSP: 合并 Firebase Auth (Windsurf) + Airwallex 域名 - KeysView.vue: 删除死代码(已被 buildCcSwitchImportDeeplink 取代) - ccswitchImport.ts: 补充 windsurf 平台 case - 修复 NewOpsHandler/RegisterGatewayRoutes/SelectAccountWithScheduler 测试签名 保留: - Antigravity newapi 兼容 (ForwardUpstream /v1/messages 透传) - Antigravity 核心(gateway_service, oauth, client, credits_overages 等) - Windsurf 全套 - Claude 网关 + TLS 指纹路由 - 其他本地 feat: P2C 调度 / viewer / context 压缩 / RPM / fallback / health
This commit is contained in:
commit
35c6c2b097
@ -1 +1 @@
|
|||||||
0.1.125
|
0.1.126
|
||||||
|
|||||||
@ -35,7 +35,7 @@ const (
|
|||||||
// - media-src 'self' data: (Firebase plays a tiny silent base64 WAV
|
// - media-src 'self' data: (Firebase plays a tiny silent base64 WAV
|
||||||
// to keep the popup channel alive across
|
// to keep the popup channel alive across
|
||||||
// browser autoplay restrictions)
|
// browser autoplay restrictions)
|
||||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com https://apis.google.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; media-src 'self' data:; frame-src https://challenges.cloudflare.com https://*.stripe.com https://*.firebaseapp.com https://accounts.google.com https://apis.google.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com https://apis.google.com https://static.airwallex.com https://checkout.airwallex.com https://static-demo.airwallex.com https://checkout-demo.airwallex.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://static.airwallex.com https://checkout.airwallex.com https://static-demo.airwallex.com https://checkout-demo.airwallex.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; media-src 'self' data:; frame-src https://challenges.cloudflare.com https://*.stripe.com https://*.firebaseapp.com https://accounts.google.com https://apis.google.com https://checkout.airwallex.com https://checkout-demo.airwallex.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||||
|
|
||||||
// UMQ(用户消息队列)模式常量
|
// UMQ(用户消息队列)模式常量
|
||||||
const (
|
const (
|
||||||
|
|||||||
@ -563,9 +563,7 @@ func normalizeCodexImportEntry(entry codexImportEntry) (*codexImportAccount, err
|
|||||||
}
|
}
|
||||||
if item.IDToken != "" {
|
if item.IDToken != "" {
|
||||||
item.Credentials["id_token"] = item.IDToken
|
item.Credentials["id_token"] = item.IDToken
|
||||||
if err := enrichCodexImportAccountFromJWT(item, item.IDToken, false, now); err != nil {
|
_ = enrichCodexImportAccountFromJWT(item, item.IDToken, false, now)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := enrichCodexImportAccountFromJWT(item, item.AccessToken, true, now); err != nil {
|
if err := enrichCodexImportAccountFromJWT(item, item.AccessToken, true, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -35,7 +35,7 @@ func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||||
h := NewOpsHandler(nil, nil)
|
h := NewOpsHandler(nil, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -48,7 +48,7 @@ func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -61,7 +61,7 @@ func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -76,7 +76,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
|||||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||||
Ops: config.OpsConfig{Enabled: false},
|
Ops: config.OpsConfig{Enabled: false},
|
||||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -89,7 +89,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -110,7 +110,7 @@ func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -124,7 +124,7 @@ func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -138,7 +138,7 @@ func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -152,7 +152,7 @@ func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -166,7 +166,7 @@ func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -182,7 +182,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
|||||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||||
Ops: config.OpsConfig{Enabled: false},
|
Ops: config.OpsConfig{Enabled: false},
|
||||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -197,7 +197,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
|||||||
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||||
sink := service.NewOpsSystemLogSink(nil)
|
sink := service.NewOpsSystemLogSink(nil)
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||||
h := NewOpsHandler(svc, nil)
|
h := NewOpsHandler(svc, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -209,7 +209,7 @@ func TestOpsSystemLogHandler_Health(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
||||||
h := NewOpsHandler(nil, nil)
|
h := NewOpsHandler(nil, nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -222,7 +222,7 @@ func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T
|
|||||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||||
Ops: config.OpsConfig{Enabled: false},
|
Ops: config.OpsConfig{Enabled: false},
|
||||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h = NewOpsHandler(svc, nil)
|
h = NewOpsHandler(svc, nil, nil)
|
||||||
r = newOpsSystemLogTestRouter(h, false)
|
r = newOpsSystemLogTestRouter(h, false)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||||
|
|||||||
@ -225,6 +225,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||||
EnableCCHSigning: settings.EnableCCHSigning,
|
EnableCCHSigning: settings.EnableCCHSigning,
|
||||||
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||||
|
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
||||||
|
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
|
||||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||||
@ -511,10 +513,12 @@ type UpdateSettingsRequest struct {
|
|||||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
|
|
||||||
// Gateway forwarding behavior
|
// Gateway forwarding behavior
|
||||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||||
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||||
|
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
|
||||||
|
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
|
||||||
|
|
||||||
// Payment visible method routing
|
// Payment visible method routing
|
||||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||||
@ -1250,6 +1254,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if req.AntigravityUserAgentVersion != nil {
|
||||||
|
normalized := strings.TrimSpace(*req.AntigravityUserAgentVersion)
|
||||||
|
req.AntigravityUserAgentVersion = &normalized
|
||||||
|
if normalized != "" && !semverPattern.MatchString(normalized) {
|
||||||
|
response.Error(c, http.StatusBadRequest, "antigravity_user_agent_version must be empty or a valid semver (e.g. 1.23.2)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
|
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
|
||||||
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
|
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
|
||||||
@ -1415,6 +1427,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return previousSettings.EnableAnthropicCacheTTL1hInjection
|
return previousSettings.EnableAnthropicCacheTTL1hInjection
|
||||||
}(),
|
}(),
|
||||||
|
RewriteMessageCacheControl: func() bool {
|
||||||
|
if req.RewriteMessageCacheControl != nil {
|
||||||
|
return *req.RewriteMessageCacheControl
|
||||||
|
}
|
||||||
|
return previousSettings.RewriteMessageCacheControl
|
||||||
|
}(),
|
||||||
|
AntigravityUserAgentVersion: func() string {
|
||||||
|
if req.AntigravityUserAgentVersion != nil {
|
||||||
|
return *req.AntigravityUserAgentVersion
|
||||||
|
}
|
||||||
|
return previousSettings.AntigravityUserAgentVersion
|
||||||
|
}(),
|
||||||
PaymentVisibleMethodAlipaySource: func() string {
|
PaymentVisibleMethodAlipaySource: func() string {
|
||||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||||
@ -1747,6 +1771,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||||
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
||||||
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
||||||
|
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
|
||||||
|
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
|
||||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||||
@ -2143,6 +2169,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
|
if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
|
||||||
changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
|
changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
|
||||||
}
|
}
|
||||||
|
if before.RewriteMessageCacheControl != after.RewriteMessageCacheControl {
|
||||||
|
changed = append(changed, "rewrite_message_cache_control")
|
||||||
|
}
|
||||||
|
if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion {
|
||||||
|
changed = append(changed, "antigravity_user_agent_version")
|
||||||
|
}
|
||||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||||
changed = append(changed, "payment_visible_method_alipay_source")
|
changed = append(changed, "payment_visible_method_alipay_source")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -158,10 +158,12 @@ type SystemSettings struct {
|
|||||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
|
|
||||||
// Gateway forwarding behavior
|
// Gateway forwarding behavior
|
||||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||||
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||||
|
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
|
||||||
|
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
|
||||||
|
|
||||||
// Web Search Emulation
|
// Web Search Emulation
|
||||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||||
|
|||||||
@ -459,6 +459,7 @@ type PublicOrderResult struct {
|
|||||||
Amount float64 `json:"amount"`
|
Amount float64 `json:"amount"`
|
||||||
PayAmount float64 `json:"pay_amount"`
|
PayAmount float64 `json:"pay_amount"`
|
||||||
FeeRate float64 `json:"fee_rate"`
|
FeeRate float64 `json:"fee_rate"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
PaymentType string `json:"payment_type"`
|
PaymentType string `json:"payment_type"`
|
||||||
OrderType string `json:"order_type"`
|
OrderType string `json:"order_type"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
@ -481,6 +482,7 @@ func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
|
|||||||
Amount: order.Amount,
|
Amount: order.Amount,
|
||||||
PayAmount: order.PayAmount,
|
PayAmount: order.PayAmount,
|
||||||
FeeRate: order.FeeRate,
|
FeeRate: order.FeeRate,
|
||||||
|
Currency: service.PaymentOrderCurrency(order),
|
||||||
PaymentType: order.PaymentType,
|
PaymentType: order.PaymentType,
|
||||||
OrderType: order.OrderType,
|
OrderType: order.OrderType,
|
||||||
Status: order.Status,
|
Status: order.Status,
|
||||||
@ -554,24 +556,67 @@ func isMobile(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
|
type PaymentOrderResult struct {
|
||||||
if len(orders) == 0 {
|
ID int64 `json:"id"`
|
||||||
return orders
|
UserID int64 `json:"user_id"`
|
||||||
}
|
Amount float64 `json:"amount"`
|
||||||
out := make([]*dbent.PaymentOrder, 0, len(orders))
|
PayAmount float64 `json:"pay_amount"`
|
||||||
|
FeeRate float64 `json:"fee_rate"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
PaymentType string `json:"payment_type"`
|
||||||
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
OrderType string `json:"order_type"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
PaidAt *time.Time `json:"paid_at,omitempty"`
|
||||||
|
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||||
|
RefundAmount float64 `json:"refund_amount"`
|
||||||
|
RefundReason *string `json:"refund_reason,omitempty"`
|
||||||
|
RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
|
||||||
|
RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
|
||||||
|
RefundRequestReason *string `json:"refund_request_reason,omitempty"`
|
||||||
|
PlanID *int64 `json:"plan_id,omitempty"`
|
||||||
|
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []PaymentOrderResult {
|
||||||
|
out := make([]PaymentOrderResult, 0, len(orders))
|
||||||
for _, order := range orders {
|
for _, order := range orders {
|
||||||
out = append(out, sanitizePaymentOrderForResponse(order))
|
if item := sanitizePaymentOrderForResponse(order); item != nil {
|
||||||
|
out = append(out, *item)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
|
func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *PaymentOrderResult {
|
||||||
if order == nil {
|
if order == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
cloned := *order
|
return &PaymentOrderResult{
|
||||||
cloned.ProviderSnapshot = nil
|
ID: order.ID,
|
||||||
return &cloned
|
UserID: order.UserID,
|
||||||
|
Amount: order.Amount,
|
||||||
|
PayAmount: order.PayAmount,
|
||||||
|
FeeRate: order.FeeRate,
|
||||||
|
Currency: service.PaymentOrderCurrency(order),
|
||||||
|
PaymentType: order.PaymentType,
|
||||||
|
OutTradeNo: order.OutTradeNo,
|
||||||
|
Status: order.Status,
|
||||||
|
OrderType: order.OrderType,
|
||||||
|
CreatedAt: order.CreatedAt,
|
||||||
|
ExpiresAt: order.ExpiresAt,
|
||||||
|
PaidAt: order.PaidAt,
|
||||||
|
CompletedAt: order.CompletedAt,
|
||||||
|
RefundAmount: order.RefundAmount,
|
||||||
|
RefundReason: order.RefundReason,
|
||||||
|
RefundRequestedAt: order.RefundRequestedAt,
|
||||||
|
RefundRequestedBy: order.RefundRequestedBy,
|
||||||
|
RefundRequestReason: order.RefundRequestReason,
|
||||||
|
PlanID: order.PlanID,
|
||||||
|
ProviderInstanceID: order.ProviderInstanceID,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isWeChatBrowser(c *gin.Context) bool {
|
func isWeChatBrowser(c *gin.Context) bool {
|
||||||
|
|||||||
@ -114,6 +114,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
|||||||
SetExpiresAt(time.Now().Add(time.Hour)).
|
SetExpiresAt(time.Now().Add(time.Hour)).
|
||||||
SetClientIP("127.0.0.1").
|
SetClientIP("127.0.0.1").
|
||||||
SetSrcHost("api.example.com").
|
SetSrcHost("api.example.com").
|
||||||
|
SetProviderSnapshot(map[string]any{"currency": "HKD"}).
|
||||||
Save(context.Background())
|
Save(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -141,6 +142,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
|||||||
Amount float64 `json:"amount"`
|
Amount float64 `json:"amount"`
|
||||||
PayAmount float64 `json:"pay_amount"`
|
PayAmount float64 `json:"pay_amount"`
|
||||||
FeeRate float64 `json:"fee_rate"`
|
FeeRate float64 `json:"fee_rate"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
PaymentType string `json:"payment_type"`
|
PaymentType string `json:"payment_type"`
|
||||||
OrderType string `json:"order_type"`
|
OrderType string `json:"order_type"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
@ -155,6 +157,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
|||||||
require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
|
require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
|
||||||
require.Equal(t, 90.64, resp.Data.PayAmount)
|
require.Equal(t, 90.64, resp.Data.PayAmount)
|
||||||
require.Equal(t, 0.03, resp.Data.FeeRate)
|
require.Equal(t, 0.03, resp.Data.FeeRate)
|
||||||
|
require.Equal(t, "HKD", resp.Data.Currency)
|
||||||
require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
|
require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
|
||||||
require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
|
require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
|
||||||
require.Equal(t, service.OrderStatusPending, resp.Data.Status)
|
require.Equal(t, service.OrderStatusPending, resp.Data.Status)
|
||||||
@ -202,6 +205,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
|||||||
SetPaidAt(time.Now()).
|
SetPaidAt(time.Now()).
|
||||||
SetClientIP("127.0.0.1").
|
SetClientIP("127.0.0.1").
|
||||||
SetSrcHost("api.example.com").
|
SetSrcHost("api.example.com").
|
||||||
|
SetProviderSnapshot(map[string]any{"currency": "USD"}).
|
||||||
Save(context.Background())
|
Save(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -242,6 +246,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
|||||||
require.Equal(t, 100.0, resp.Data["amount"])
|
require.Equal(t, 100.0, resp.Data["amount"])
|
||||||
require.Equal(t, 103.0, resp.Data["pay_amount"])
|
require.Equal(t, 103.0, resp.Data["pay_amount"])
|
||||||
require.Equal(t, 0.03, resp.Data["fee_rate"])
|
require.Equal(t, 0.03, resp.Data["fee_rate"])
|
||||||
|
require.Equal(t, "USD", resp.Data["currency"])
|
||||||
require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
|
require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
|
||||||
require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
|
require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
|
||||||
require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
|
require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -60,6 +61,12 @@ func (h *PaymentWebhookHandler) StripeWebhook(c *gin.Context) {
|
|||||||
h.handleNotify(c, payment.TypeStripe)
|
h.handleNotify(c, payment.TypeStripe)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AirwallexWebhook 处理空中云汇 Webhook 事件。
|
||||||
|
// POST /api/v1/payment/webhook/airwallex
|
||||||
|
func (h *PaymentWebhookHandler) AirwallexWebhook(c *gin.Context) {
|
||||||
|
h.handleNotify(c, payment.TypeAirwallex)
|
||||||
|
}
|
||||||
|
|
||||||
// handleNotify is the shared logic for all provider webhook handlers.
|
// handleNotify is the shared logic for all provider webhook handlers.
|
||||||
func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) {
|
func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) {
|
||||||
var rawBody string
|
var rawBody string
|
||||||
@ -146,6 +153,17 @@ func extractOutTradeNo(rawBody, providerKey string) string {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return values.Get("out_trade_no")
|
return values.Get("out_trade_no")
|
||||||
}
|
}
|
||||||
|
case payment.TypeAirwallex:
|
||||||
|
var payload struct {
|
||||||
|
Data struct {
|
||||||
|
Object struct {
|
||||||
|
MerchantOrderID string `json:"merchant_order_id"`
|
||||||
|
} `json:"object"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(rawBody), &payload); err == nil {
|
||||||
|
return strings.TrimSpace(payload.Data.Object.MerchantOrderID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// For other providers (Stripe, Alipay direct, WxPay direct), the registry
|
// For other providers (Stripe, Alipay direct, WxPay direct), the registry
|
||||||
// typically has only one instance, so no instance lookup is needed.
|
// typically has only one instance, so no instance lookup is needed.
|
||||||
@ -183,14 +201,14 @@ const (
|
|||||||
wxpaySuccessMessage = "成功"
|
wxpaySuccessMessage = "成功"
|
||||||
)
|
)
|
||||||
|
|
||||||
// writeSuccessResponse sends the provider-specific success response.
|
// writeSuccessResponse 返回各支付服务商要求的成功响应。
|
||||||
// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
|
// 微信支付需要 JSON {"code":"SUCCESS","message":"成功"};
|
||||||
// Stripe expects an empty 200; others accept plain text "success".
|
// Stripe 和空中云汇接受空 200,其它服务商接受纯文本 "success"。
|
||||||
func writeSuccessResponse(c *gin.Context, providerKey string) {
|
func writeSuccessResponse(c *gin.Context, providerKey string) {
|
||||||
switch providerKey {
|
switch providerKey {
|
||||||
case payment.TypeWxpay:
|
case payment.TypeWxpay:
|
||||||
c.JSON(http.StatusOK, wxpaySuccessResponse{Code: wxpaySuccessCode, Message: wxpaySuccessMessage})
|
c.JSON(http.StatusOK, wxpaySuccessResponse{Code: wxpaySuccessCode, Message: wxpaySuccessMessage})
|
||||||
case payment.TypeStripe:
|
case payment.TypeStripe, payment.TypeAirwallex:
|
||||||
c.String(http.StatusOK, "")
|
c.String(http.StatusOK, "")
|
||||||
default:
|
default:
|
||||||
c.String(http.StatusOK, "success")
|
c.String(http.StatusOK, "success")
|
||||||
|
|||||||
@ -47,6 +47,13 @@ func TestWriteSuccessResponse(t *testing.T) {
|
|||||||
wantContentType: "text/plain",
|
wantContentType: "text/plain",
|
||||||
wantBody: "",
|
wantBody: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "airwallex returns empty 200",
|
||||||
|
providerKey: payment.TypeAirwallex,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
wantContentType: "text/plain",
|
||||||
|
wantBody: "",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "easypay returns plain text success",
|
name: "easypay returns plain text success",
|
||||||
providerKey: "easypay",
|
providerKey: "easypay",
|
||||||
@ -165,6 +172,12 @@ func TestExtractOutTradeNo(t *testing.T) {
|
|||||||
rawBody: "{}",
|
rawBody: "{}",
|
||||||
want: "",
|
want: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "airwallex payment intent payload",
|
||||||
|
providerKey: payment.TypeAirwallex,
|
||||||
|
rawBody: `{"name":"payment_intent.succeeded","data":{"object":{"merchant_order_id":"sub2_awx_123"}}}`,
|
||||||
|
want: "sub2_awx_123",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@ -1,24 +1,9 @@
|
|||||||
package payment
|
package payment
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/shopspring/decimal"
|
|
||||||
)
|
|
||||||
|
|
||||||
const centsPerYuan = 100
|
|
||||||
|
|
||||||
// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
|
|
||||||
// Uses shopspring/decimal for precision.
|
|
||||||
func YuanToFen(yuanStr string) (int64, error) {
|
func YuanToFen(yuanStr string) (int64, error) {
|
||||||
d, err := decimal.NewFromString(yuanStr)
|
return AmountToMinorUnit(yuanStr, DefaultPaymentCurrency)
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("invalid amount: %s", yuanStr)
|
|
||||||
}
|
|
||||||
return d.Mul(decimal.NewFromInt(centsPerYuan)).IntPart(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
|
|
||||||
func FenToYuan(fen int64) float64 {
|
func FenToYuan(fen int64) float64 {
|
||||||
return decimal.NewFromInt(fen).Div(decimal.NewFromInt(centsPerYuan)).InexactFloat64()
|
return MinorUnitToAmount(fen, DefaultPaymentCurrency)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -126,3 +126,104 @@ func TestYuanToFenRoundTrip(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPaymentCurrencyHelpers(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
currency string
|
||||||
|
amount string
|
||||||
|
wantMinor int64
|
||||||
|
wantBack float64
|
||||||
|
}{
|
||||||
|
{name: "hkd uses cents", currency: "hkd", amount: "12.34", wantMinor: 1234, wantBack: 12.34},
|
||||||
|
{name: "jpy has no minor unit", currency: "JPY", amount: "12", wantMinor: 12, wantBack: 12},
|
||||||
|
{name: "kwd uses three decimal minor units", currency: "KWD", amount: "12.345", wantMinor: 12345, wantBack: 12.345},
|
||||||
|
{name: "isk uses Stripe legacy two-decimal API amount", currency: "ISK", amount: "12", wantMinor: 1200, wantBack: 12},
|
||||||
|
{name: "ugx uses Stripe legacy two-decimal API amount", currency: "UGX", amount: "12.00", wantMinor: 1200, wantBack: 12},
|
||||||
|
{name: "empty currency defaults to cny", currency: "", amount: "1.23", wantMinor: 123, wantBack: 1.23},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := AmountToMinorUnit(tt.amount, tt.currency)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AmountToMinorUnit(%q, %q) unexpected error: %v", tt.amount, tt.currency, err)
|
||||||
|
}
|
||||||
|
if got != tt.wantMinor {
|
||||||
|
t.Fatalf("AmountToMinorUnit(%q, %q) = %d, want %d", tt.amount, tt.currency, got, tt.wantMinor)
|
||||||
|
}
|
||||||
|
back := MinorUnitToAmount(got, tt.currency)
|
||||||
|
if math.Abs(back-tt.wantBack) > 1e-9 {
|
||||||
|
t.Fatalf("MinorUnitToAmount(%d, %q) = %f, want %f", got, tt.currency, back, tt.wantBack)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatAmountForCurrency(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
currency string
|
||||||
|
amount float64
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{currency: "CNY", amount: 12.3, want: "12.30"},
|
||||||
|
{currency: "JPY", amount: 12, want: "12"},
|
||||||
|
{currency: "KWD", amount: 12.345, want: "12.345"},
|
||||||
|
{currency: "ISK", amount: 12, want: "12"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.currency, func(t *testing.T) {
|
||||||
|
if got := FormatAmountForCurrency(tt.amount, tt.currency); got != tt.want {
|
||||||
|
t.Fatalf("FormatAmountForCurrency(%v, %q) = %q, want %q", tt.amount, tt.currency, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAmountToMinorUnitRejectsUnsupportedPrecision(t *testing.T) {
|
||||||
|
if _, err := AmountToMinorUnit("100.50", "JPY"); err == nil {
|
||||||
|
t.Fatal("expected fractional JPY amount to fail")
|
||||||
|
}
|
||||||
|
if _, err := AmountToMinorUnit("100.50", "ISK"); err == nil {
|
||||||
|
t.Fatal("expected fractional ISK amount to fail")
|
||||||
|
}
|
||||||
|
if _, err := AmountToMinorUnit("100.50", "UGX"); err == nil {
|
||||||
|
t.Fatal("expected fractional UGX amount to fail")
|
||||||
|
}
|
||||||
|
if _, err := AmountToMinorUnit("12.345", "HKD"); err == nil {
|
||||||
|
t.Fatal("expected amount with more than two decimal places to fail")
|
||||||
|
}
|
||||||
|
if _, err := AmountToMinorUnit("12.3456", "KWD"); err == nil {
|
||||||
|
t.Fatal("expected amount with more than three decimal places to fail")
|
||||||
|
}
|
||||||
|
if got, err := AmountToMinorUnit("100.00", "JPY"); err != nil || got != 100 {
|
||||||
|
t.Fatalf("AmountToMinorUnit integer-form JPY = (%d, %v), want (100, nil)", got, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThreeDecimalPaymentCurrencies(t *testing.T) {
|
||||||
|
for _, currency := range []string{"BHD", "IQD", "JOD", "KWD", "LYD", "OMR", "TND"} {
|
||||||
|
t.Run(currency, func(t *testing.T) {
|
||||||
|
got, err := AmountToMinorUnit("12.345", currency)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AmountToMinorUnit(%q, %q) unexpected error: %v", "12.345", currency, err)
|
||||||
|
}
|
||||||
|
if got != 12345 {
|
||||||
|
t.Fatalf("AmountToMinorUnit(%q, %q) = %d, want 12345", "12.345", currency, got)
|
||||||
|
}
|
||||||
|
if back := MinorUnitToAmount(got, currency); math.Abs(back-12.345) > 1e-9 {
|
||||||
|
t.Fatalf("MinorUnitToAmount(%d, %q) = %f, want 12.345", got, currency, back)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePaymentCurrencyRejectsInvalidCodes(t *testing.T) {
|
||||||
|
if _, err := NormalizePaymentCurrency("HK"); err == nil {
|
||||||
|
t.Fatal("expected invalid two-letter currency to fail")
|
||||||
|
}
|
||||||
|
if _, err := NormalizePaymentCurrency("US1"); err == nil {
|
||||||
|
t.Fatal("expected non-letter currency to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
118
backend/internal/payment/currency.go
Normal file
118
backend/internal/payment/currency.go
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
package payment
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultPaymentCurrency = "CNY"
|
||||||
|
|
||||||
|
type paymentCurrencyAmountUnit struct {
|
||||||
|
apiMinorUnit int
|
||||||
|
maxFractionDigits int
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
zeroDecimalAmountUnit = paymentCurrencyAmountUnit{apiMinorUnit: 0, maxFractionDigits: 0}
|
||||||
|
twoDecimalAmountUnit = paymentCurrencyAmountUnit{apiMinorUnit: 2, maxFractionDigits: 2}
|
||||||
|
threeDecimalAmountUnit = paymentCurrencyAmountUnit{apiMinorUnit: 3, maxFractionDigits: 3}
|
||||||
|
stripeLegacyZeroAmount = paymentCurrencyAmountUnit{apiMinorUnit: 2, maxFractionDigits: 0}
|
||||||
|
)
|
||||||
|
|
||||||
|
var paymentCurrencyAmountUnits = map[string]paymentCurrencyAmountUnit{
|
||||||
|
"BIF": zeroDecimalAmountUnit,
|
||||||
|
"CLP": zeroDecimalAmountUnit,
|
||||||
|
"DJF": zeroDecimalAmountUnit,
|
||||||
|
"GNF": zeroDecimalAmountUnit,
|
||||||
|
"JPY": zeroDecimalAmountUnit,
|
||||||
|
"KMF": zeroDecimalAmountUnit,
|
||||||
|
"KRW": zeroDecimalAmountUnit,
|
||||||
|
"MGA": zeroDecimalAmountUnit,
|
||||||
|
"PYG": zeroDecimalAmountUnit,
|
||||||
|
"RWF": zeroDecimalAmountUnit,
|
||||||
|
"VND": zeroDecimalAmountUnit,
|
||||||
|
"VUV": zeroDecimalAmountUnit,
|
||||||
|
"XAF": zeroDecimalAmountUnit,
|
||||||
|
"XOF": zeroDecimalAmountUnit,
|
||||||
|
"XPF": zeroDecimalAmountUnit,
|
||||||
|
"ISK": stripeLegacyZeroAmount,
|
||||||
|
"UGX": stripeLegacyZeroAmount,
|
||||||
|
"BHD": threeDecimalAmountUnit,
|
||||||
|
"IQD": threeDecimalAmountUnit,
|
||||||
|
"JOD": threeDecimalAmountUnit,
|
||||||
|
"KWD": threeDecimalAmountUnit,
|
||||||
|
"LYD": threeDecimalAmountUnit,
|
||||||
|
"OMR": threeDecimalAmountUnit,
|
||||||
|
"TND": threeDecimalAmountUnit,
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizePaymentCurrency(raw string) (string, error) {
|
||||||
|
currency := strings.ToUpper(strings.TrimSpace(raw))
|
||||||
|
if currency == "" {
|
||||||
|
return DefaultPaymentCurrency, nil
|
||||||
|
}
|
||||||
|
if len(currency) != 3 {
|
||||||
|
return "", fmt.Errorf("payment currency must be a 3-letter ISO currency code")
|
||||||
|
}
|
||||||
|
for _, ch := range currency {
|
||||||
|
if ch < 'A' || ch > 'Z' {
|
||||||
|
return "", fmt.Errorf("payment currency must be a 3-letter ISO currency code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return currency, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CurrencyMinorUnit(currency string) int {
|
||||||
|
return paymentCurrencyAmountUnitFor(currency).apiMinorUnit
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrencyMaxFractionDigits 返回支付金额允许展示和输入的小数位数。
|
||||||
|
func CurrencyMaxFractionDigits(currency string) int {
|
||||||
|
return paymentCurrencyAmountUnitFor(currency).maxFractionDigits
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatAmountForCurrency 按币种允许的小数位格式化支付金额。
|
||||||
|
func FormatAmountForCurrency(amount float64, currency string) string {
|
||||||
|
return decimal.NewFromFloat(amount).StringFixed(int32(CurrencyMaxFractionDigits(currency)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func paymentCurrencyAmountUnitFor(currency string) paymentCurrencyAmountUnit {
|
||||||
|
normalized, err := NormalizePaymentCurrency(currency)
|
||||||
|
if err != nil {
|
||||||
|
return twoDecimalAmountUnit
|
||||||
|
}
|
||||||
|
if amountUnit, ok := paymentCurrencyAmountUnits[normalized]; ok {
|
||||||
|
return amountUnit
|
||||||
|
}
|
||||||
|
return twoDecimalAmountUnit
|
||||||
|
}
|
||||||
|
|
||||||
|
func AmountToMinorUnit(amountStr, currency string) (int64, error) {
|
||||||
|
d, err := decimal.NewFromString(strings.TrimSpace(amountStr))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid amount: %s", amountStr)
|
||||||
|
}
|
||||||
|
normalizedCurrency, err := NormalizePaymentCurrency(currency)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
amountUnit := paymentCurrencyAmountUnitFor(normalizedCurrency)
|
||||||
|
precisionFactor := decimal.New(1, int32(amountUnit.maxFractionDigits))
|
||||||
|
scaledForPrecision := d.Mul(precisionFactor)
|
||||||
|
if !scaledForPrecision.Equal(scaledForPrecision.Truncate(0)) {
|
||||||
|
if amountUnit.maxFractionDigits == 0 {
|
||||||
|
return 0, fmt.Errorf("payment amount for %s must be a whole number", normalizedCurrency)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("payment amount for %s must not have more than %d decimal places", normalizedCurrency, amountUnit.maxFractionDigits)
|
||||||
|
}
|
||||||
|
factor := decimal.New(1, int32(amountUnit.apiMinorUnit))
|
||||||
|
minorAmount := d.Mul(factor)
|
||||||
|
return minorAmount.IntPart(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MinorUnitToAmount(value int64, currency string) float64 {
|
||||||
|
factor := decimal.New(1, int32(CurrencyMinorUnit(currency)))
|
||||||
|
return decimal.NewFromInt(value).Div(factor).InexactFloat64()
|
||||||
|
}
|
||||||
@ -4,16 +4,18 @@ import (
|
|||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CalculatePayAmount computes the total pay amount given a recharge amount and
|
|
||||||
// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
|
|
||||||
// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
|
|
||||||
// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
|
|
||||||
func CalculatePayAmount(rechargeAmount float64, feeRate float64) string {
|
func CalculatePayAmount(rechargeAmount float64, feeRate float64) string {
|
||||||
|
return CalculatePayAmountForCurrency(rechargeAmount, feeRate, DefaultPaymentCurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculatePayAmountForCurrency 按币种精度计算应付金额,手续费向上取整到该币种最小支付单位。
|
||||||
|
func CalculatePayAmountForCurrency(rechargeAmount float64, feeRate float64, currency string) string {
|
||||||
|
fractionDigits := int32(CurrencyMaxFractionDigits(currency))
|
||||||
amount := decimal.NewFromFloat(rechargeAmount)
|
amount := decimal.NewFromFloat(rechargeAmount)
|
||||||
if feeRate <= 0 {
|
if feeRate <= 0 {
|
||||||
return amount.StringFixed(2)
|
return amount.StringFixed(fractionDigits)
|
||||||
}
|
}
|
||||||
rate := decimal.NewFromFloat(feeRate)
|
rate := decimal.NewFromFloat(feeRate)
|
||||||
fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(2)
|
fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(fractionDigits)
|
||||||
return amount.Add(fee).StringFixed(2)
|
return amount.Add(fee).StringFixed(fractionDigits)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -109,3 +109,55 @@ func TestCalculatePayAmount(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCalculatePayAmountForCurrency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
amount float64
|
||||||
|
feeRate float64
|
||||||
|
currency string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero decimal currency rounds fee up to whole unit",
|
||||||
|
amount: 100,
|
||||||
|
feeRate: 2.5,
|
||||||
|
currency: "JPY",
|
||||||
|
expected: "103",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three decimal currency keeps three decimal places",
|
||||||
|
amount: 12.345,
|
||||||
|
feeRate: 1,
|
||||||
|
currency: "KWD",
|
||||||
|
expected: "12.469",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stripe legacy zero decimal currency displays whole unit",
|
||||||
|
amount: 100,
|
||||||
|
feeRate: 2.5,
|
||||||
|
currency: "ISK",
|
||||||
|
expected: "103",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default currency keeps existing two decimal behavior",
|
||||||
|
amount: 10,
|
||||||
|
feeRate: 3.33,
|
||||||
|
currency: "CNY",
|
||||||
|
expected: "10.34",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := CalculatePayAmountForCurrency(tt.amount, tt.feeRate, tt.currency)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Fatalf("CalculatePayAmountForCurrency(%v, %v, %q) = %q, want %q", tt.amount, tt.feeRate, tt.currency, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
639
backend/internal/payment/provider/airwallex.go
Normal file
639
backend/internal/payment/provider/airwallex.go
Normal file
@ -0,0 +1,639 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
airwallexDemoAPIBase = "https://api-demo.airwallex.com/api/v1"
|
||||||
|
airwallexProdAPIBase = "https://api.airwallex.com/api/v1"
|
||||||
|
airwallexDefaultCountry = "CN"
|
||||||
|
airwallexHTTPTimeout = 15 * time.Second
|
||||||
|
airwallexMaxResponseSize = 1 << 20
|
||||||
|
airwallexMaxErrorSummary = 512
|
||||||
|
airwallexTokenSkew = 2 * time.Minute
|
||||||
|
airwallexWebhookTolerance = 5 * time.Minute
|
||||||
|
|
||||||
|
airwallexEventPaymentSucceeded = "payment_intent.succeeded"
|
||||||
|
airwallexEventPaymentCancelled = "payment_intent.cancelled"
|
||||||
|
|
||||||
|
airwallexPaymentStatusSucceeded = "SUCCEEDED"
|
||||||
|
airwallexPaymentStatusCancelled = "CANCELLED"
|
||||||
|
airwallexRefundStatusReceived = "RECEIVED"
|
||||||
|
airwallexRefundStatusAccepted = "ACCEPTED"
|
||||||
|
airwallexRefundStatusSettled = "SETTLED"
|
||||||
|
airwallexRefundStatusFailed = "FAILED"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Airwallex struct {
|
||||||
|
instanceID string
|
||||||
|
config map[string]string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexTokenState struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
token string
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var airwallexAccessTokens sync.Map
|
||||||
|
|
||||||
|
func NewAirwallex(instanceID string, config map[string]string) (*Airwallex, error) {
|
||||||
|
for _, k := range []string{"clientId", "apiKey", "webhookSecret", "apiBase"} {
|
||||||
|
if strings.TrimSpace(config[k]) == "" {
|
||||||
|
return nil, fmt.Errorf("airwallex config missing required key: %s", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg := cloneStringMap(config)
|
||||||
|
apiBase, err := normalizeAirwallexAPIBase(cfg["apiBase"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg["apiBase"] = apiBase
|
||||||
|
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex config currency: %w", err)
|
||||||
|
}
|
||||||
|
cfg["currency"] = currency
|
||||||
|
countryCode, err := normalizeAirwallexCountryCode(cfg["countryCode"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg["countryCode"] = countryCode
|
||||||
|
return &Airwallex{
|
||||||
|
instanceID: instanceID,
|
||||||
|
config: cfg,
|
||||||
|
httpClient: &http.Client{Timeout: airwallexHTTPTimeout},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAirwallexCountryCode(raw string) (string, error) {
|
||||||
|
countryCode := strings.ToUpper(strings.TrimSpace(raw))
|
||||||
|
if countryCode == "" {
|
||||||
|
return airwallexDefaultCountry, nil
|
||||||
|
}
|
||||||
|
if len(countryCode) != 2 {
|
||||||
|
return "", fmt.Errorf("airwallex config countryCode must be a two-letter ISO country code")
|
||||||
|
}
|
||||||
|
for _, ch := range countryCode {
|
||||||
|
if ch < 'A' || ch > 'Z' {
|
||||||
|
return "", fmt.Errorf("airwallex config countryCode must be a two-letter ISO country code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return countryCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAirwallexAPIBase(raw string) (string, error) {
|
||||||
|
base := strings.TrimSpace(raw)
|
||||||
|
if base == "" {
|
||||||
|
return "", fmt.Errorf("airwallex apiBase is required")
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(base)
|
||||||
|
if err != nil || parsed.Scheme != "https" || parsed.Host == "" {
|
||||||
|
return "", fmt.Errorf("airwallex apiBase must be an HTTPS URL")
|
||||||
|
}
|
||||||
|
host := strings.ToLower(parsed.Host)
|
||||||
|
if host != "api-demo.airwallex.com" && host != "api.airwallex.com" {
|
||||||
|
return "", fmt.Errorf("airwallex apiBase host must be api-demo.airwallex.com or api.airwallex.com")
|
||||||
|
}
|
||||||
|
parsed.RawQuery = ""
|
||||||
|
parsed.Fragment = ""
|
||||||
|
parsed.RawPath = ""
|
||||||
|
parsed.Path = strings.TrimRight(parsed.Path, "/")
|
||||||
|
if parsed.Path == "" {
|
||||||
|
parsed.Path = "/api/v1"
|
||||||
|
}
|
||||||
|
if parsed.Path != "/api/v1" {
|
||||||
|
return "", fmt.Errorf("airwallex apiBase path must be /api/v1")
|
||||||
|
}
|
||||||
|
return parsed.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) Name() string { return "空中云汇" }
|
||||||
|
func (a *Airwallex) ProviderKey() string { return payment.TypeAirwallex }
|
||||||
|
func (a *Airwallex) SupportedTypes() []payment.PaymentType {
|
||||||
|
return []payment.PaymentType{payment.TypeAirwallex}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) MerchantIdentityMetadata() map[string]string {
|
||||||
|
if a == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
metadata := map[string]string{"currency": a.currency()}
|
||||||
|
if accountID := strings.TrimSpace(a.config["accountId"]); accountID != "" {
|
||||||
|
metadata["account_id"] = accountID
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) currency() string {
|
||||||
|
if a == nil {
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
currency, err := payment.NormalizePaymentCurrency(a.config["currency"])
|
||||||
|
if err != nil {
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
return currency
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||||
|
amount, err := decimal.NewFromString(req.Amount)
|
||||||
|
if err != nil || amount.LessThanOrEqual(decimal.Zero) {
|
||||||
|
return nil, fmt.Errorf("airwallex create payment: invalid amount %s", req.Amount)
|
||||||
|
}
|
||||||
|
token, err := a.accessToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currency := a.currency()
|
||||||
|
requestID := airwallexDeterministicRequestID("payment-intent", req.OrderID, req.Amount, currency)
|
||||||
|
payload := airwallexCreatePaymentIntentRequest{
|
||||||
|
RequestID: requestID,
|
||||||
|
Amount: newAirwallexRequestAmount(amount),
|
||||||
|
Currency: currency,
|
||||||
|
MerchantOrderID: req.OrderID,
|
||||||
|
ReturnURL: req.ReturnURL,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"order_id": req.OrderID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if descriptor := strings.TrimSpace(a.config["descriptor"]); descriptor != "" {
|
||||||
|
payload.Descriptor = descriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
var intent airwallexPaymentIntent
|
||||||
|
if err := a.doJSON(ctx, http.MethodPost, "/pa/payment_intents/create", token, payload, &intent); err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex create payment: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(intent.ID) == "" || strings.TrimSpace(intent.ClientSecret) == "" {
|
||||||
|
return nil, fmt.Errorf("airwallex create payment: missing payment intent id or client secret")
|
||||||
|
}
|
||||||
|
return &payment.CreatePaymentResponse{
|
||||||
|
TradeNo: intent.ID,
|
||||||
|
ClientSecret: intent.ClientSecret,
|
||||||
|
IntentID: intent.ID,
|
||||||
|
Currency: currency,
|
||||||
|
CountryCode: a.config["countryCode"],
|
||||||
|
PaymentEnv: a.checkoutEnv(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
|
||||||
|
intentID := strings.TrimSpace(tradeNo)
|
||||||
|
if intentID == "" {
|
||||||
|
return nil, fmt.Errorf("airwallex query order: missing payment intent id")
|
||||||
|
}
|
||||||
|
token, err := a.accessToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var intent airwallexPaymentIntent
|
||||||
|
if err := a.doJSON(ctx, http.MethodGet, "/pa/payment_intents/"+url.PathEscape(intentID), token, nil, &intent); err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex query order: %w", err)
|
||||||
|
}
|
||||||
|
return &payment.QueryOrderResponse{
|
||||||
|
TradeNo: intent.ID,
|
||||||
|
Status: airwallexProviderStatus(intent.Status),
|
||||||
|
Amount: intent.Amount.InexactFloat64(),
|
||||||
|
Metadata: a.intentMetadata(intent, ""),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) VerifyNotification(_ context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
|
||||||
|
if err := verifyAirwallexWebhookSignature(rawBody, headers, a.config["webhookSecret"], time.Now()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var event airwallexWebhookEvent
|
||||||
|
if err := json.Unmarshal([]byte(rawBody), &event); err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex parse webhook: %w", err)
|
||||||
|
}
|
||||||
|
switch event.Name {
|
||||||
|
case airwallexEventPaymentSucceeded, airwallexEventPaymentCancelled:
|
||||||
|
default:
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var intent airwallexPaymentIntent
|
||||||
|
if err := json.Unmarshal(event.Data.Object, &intent); err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex parse payment intent: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(intent.ID) == "" || strings.TrimSpace(intent.MerchantOrderID) == "" {
|
||||||
|
return nil, fmt.Errorf("airwallex webhook missing payment intent id or merchant_order_id")
|
||||||
|
}
|
||||||
|
status := payment.ProviderStatusFailed
|
||||||
|
if event.Name == airwallexEventPaymentSucceeded {
|
||||||
|
if strings.ToUpper(strings.TrimSpace(intent.Status)) != airwallexPaymentStatusSucceeded {
|
||||||
|
return nil, fmt.Errorf("airwallex succeeded webhook has non-succeeded status: %s", intent.Status)
|
||||||
|
}
|
||||||
|
status = payment.NotificationStatusSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
return &payment.PaymentNotification{
|
||||||
|
TradeNo: intent.ID,
|
||||||
|
OrderID: intent.MerchantOrderID,
|
||||||
|
Amount: intent.Amount.InexactFloat64(),
|
||||||
|
Status: status,
|
||||||
|
RawData: rawBody,
|
||||||
|
Metadata: a.intentMetadata(intent, event.accountID()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||||
|
intentID := strings.TrimSpace(req.TradeNo)
|
||||||
|
if intentID == "" {
|
||||||
|
return nil, fmt.Errorf("airwallex refund missing payment intent id")
|
||||||
|
}
|
||||||
|
amount, err := decimal.NewFromString(req.Amount)
|
||||||
|
if err != nil || amount.LessThanOrEqual(decimal.Zero) {
|
||||||
|
return nil, fmt.Errorf("airwallex refund: invalid amount %s", req.Amount)
|
||||||
|
}
|
||||||
|
token, err := a.accessToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := airwallexCreateRefundRequest{
|
||||||
|
RequestID: airwallexDeterministicRequestID("refund", intentID, req.Amount),
|
||||||
|
PaymentIntentID: intentID,
|
||||||
|
Amount: newAirwallexRequestAmount(amount),
|
||||||
|
Reason: strings.TrimSpace(req.Reason),
|
||||||
|
}
|
||||||
|
if payload.Reason == "" {
|
||||||
|
payload.Reason = "refund"
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp airwallexRefund
|
||||||
|
if err := a.doJSON(ctx, http.MethodPost, "/pa/refunds/create", token, payload, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("airwallex refund: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(resp.ID) == "" {
|
||||||
|
return nil, fmt.Errorf("airwallex refund: missing refund id")
|
||||||
|
}
|
||||||
|
refundResp := &payment.RefundResponse{
|
||||||
|
RefundID: resp.ID,
|
||||||
|
Status: airwallexRefundProviderStatus(resp.Status),
|
||||||
|
}
|
||||||
|
if refundResp.Status != payment.ProviderStatusSuccess {
|
||||||
|
return refundResp, fmt.Errorf("airwallex refund not settled: status %s", strings.ToUpper(strings.TrimSpace(resp.Status)))
|
||||||
|
}
|
||||||
|
return refundResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) CancelPayment(ctx context.Context, tradeNo string) error {
|
||||||
|
intentID := strings.TrimSpace(tradeNo)
|
||||||
|
if intentID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
token, err := a.accessToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("airwallex auth: %w", err)
|
||||||
|
}
|
||||||
|
var intent airwallexPaymentIntent
|
||||||
|
if err := a.doJSON(ctx, http.MethodPost, "/pa/payment_intents/"+url.PathEscape(intentID)+"/cancel", token, nil, &intent); err != nil {
|
||||||
|
return fmt.Errorf("airwallex cancel payment: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) intentMetadata(intent airwallexPaymentIntent, accountID string) map[string]string {
|
||||||
|
metadata := map[string]string{
|
||||||
|
"currency": strings.ToUpper(strings.TrimSpace(intent.Currency)),
|
||||||
|
"status": strings.ToUpper(strings.TrimSpace(intent.Status)),
|
||||||
|
}
|
||||||
|
if accountID = strings.TrimSpace(accountID); accountID != "" {
|
||||||
|
metadata["account_id"] = accountID
|
||||||
|
} else if configured := strings.TrimSpace(a.config["accountId"]); configured != "" {
|
||||||
|
metadata["account_id"] = configured
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) checkoutEnv() string {
|
||||||
|
if strings.EqualFold(a.config["apiBase"], airwallexProdAPIBase) {
|
||||||
|
return "prod"
|
||||||
|
}
|
||||||
|
return "demo"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) accessToken(ctx context.Context) (string, error) {
|
||||||
|
cacheKey := a.tokenCacheKey()
|
||||||
|
rawState, _ := airwallexAccessTokens.LoadOrStore(cacheKey, &airwallexTokenState{})
|
||||||
|
state, ok := rawState.(*airwallexTokenState)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("airwallex auth token cache state type mismatch")
|
||||||
|
}
|
||||||
|
state.mu.Lock()
|
||||||
|
defer state.mu.Unlock()
|
||||||
|
|
||||||
|
if state.token != "" && time.Now().Add(airwallexTokenSkew).Before(state.expiresAt) {
|
||||||
|
return state.token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.config["apiBase"]+"/authentication/login", nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("x-client-id", a.config["clientId"])
|
||||||
|
req.Header.Set("x-api-key", a.config["apiKey"])
|
||||||
|
if accountID := strings.TrimSpace(a.config["accountId"]); accountID != "" {
|
||||||
|
req.Header.Set("x-login-as", accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, status, err := a.do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if status < http.StatusOK || status >= http.StatusMultipleChoices {
|
||||||
|
return "", formatAirwallexAuthHTTPError(status, body)
|
||||||
|
}
|
||||||
|
var resp airwallexAuthResponse
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return "", fmt.Errorf("parse authentication response: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(resp.Token) == "" {
|
||||||
|
return "", fmt.Errorf("authentication response missing token")
|
||||||
|
}
|
||||||
|
expiresAt, err := parseAirwallexTime(resp.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(25 * time.Minute)
|
||||||
|
}
|
||||||
|
state.token = resp.Token
|
||||||
|
state.expiresAt = expiresAt
|
||||||
|
return state.token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatAirwallexAuthHTTPError(status int, body []byte) error {
|
||||||
|
summary := summarizeAirwallexResponse(body)
|
||||||
|
if status == http.StatusUnauthorized || status == http.StatusForbidden {
|
||||||
|
return fmt.Errorf("authentication HTTP %d: %s; Airwallex credentials were rejected, check Client ID/API Key, API Base environment (sandbox: https://api-demo.airwallex.com/api/v1, production: https://api.airwallex.com/api/v1), and Account ID (leave it empty for single-account scoped keys)", status, summary)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("authentication HTTP %d: %s", status, summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) tokenCacheKey() string {
|
||||||
|
sum := sha256.Sum256([]byte(a.config["apiKey"]))
|
||||||
|
return a.config["apiBase"] + "|" + a.config["clientId"] + "|" + strings.TrimSpace(a.config["accountId"]) + "|" + hex.EncodeToString(sum[:8])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) doJSON(ctx context.Context, method, path, token string, payload any, out any) error {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if payload != nil {
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bodyReader = bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, a.config["apiBase"]+path, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if token != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
if accountID := strings.TrimSpace(a.config["accountId"]); accountID != "" {
|
||||||
|
req.Header.Set("x-on-behalf-of", accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, status, err := a.do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if status < http.StatusOK || status >= http.StatusMultipleChoices {
|
||||||
|
return fmt.Errorf("HTTP %d: %s", status, summarizeAirwallexResponse(body))
|
||||||
|
}
|
||||||
|
if out == nil || len(bytes.TrimSpace(body)) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, out); err != nil {
|
||||||
|
return fmt.Errorf("parse response: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Airwallex) do(req *http.Request) ([]byte, int, error) {
|
||||||
|
client := a.httpClient
|
||||||
|
if client == nil {
|
||||||
|
client = &http.Client{Timeout: airwallexHTTPTimeout}
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, airwallexMaxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp.StatusCode, err
|
||||||
|
}
|
||||||
|
return body, resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func airwallexProviderStatus(status string) string {
|
||||||
|
switch strings.ToUpper(strings.TrimSpace(status)) {
|
||||||
|
case airwallexPaymentStatusSucceeded:
|
||||||
|
return payment.ProviderStatusPaid
|
||||||
|
case airwallexPaymentStatusCancelled:
|
||||||
|
return payment.ProviderStatusFailed
|
||||||
|
default:
|
||||||
|
return payment.ProviderStatusPending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func airwallexRefundProviderStatus(status string) string {
|
||||||
|
switch strings.ToUpper(strings.TrimSpace(status)) {
|
||||||
|
case airwallexRefundStatusSettled:
|
||||||
|
return payment.ProviderStatusSuccess
|
||||||
|
case airwallexRefundStatusFailed:
|
||||||
|
return payment.ProviderStatusFailed
|
||||||
|
case airwallexRefundStatusReceived, airwallexRefundStatusAccepted:
|
||||||
|
return payment.ProviderStatusPending
|
||||||
|
default:
|
||||||
|
return payment.ProviderStatusPending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func airwallexDeterministicRequestID(parts ...string) string {
|
||||||
|
hash := sha256.Sum256([]byte(strings.Join(parts, "\x00")))
|
||||||
|
var id uuid.UUID
|
||||||
|
copy(id[:], hash[:16])
|
||||||
|
id[6] = (id[6] & 0x0f) | 0x40
|
||||||
|
id[8] = (id[8] & 0x3f) | 0x80
|
||||||
|
return id.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyAirwallexWebhookSignature(rawBody string, headers map[string]string, secret string, now time.Time) error {
|
||||||
|
secret = strings.TrimSpace(secret)
|
||||||
|
if secret == "" {
|
||||||
|
return fmt.Errorf("airwallex webhookSecret not configured")
|
||||||
|
}
|
||||||
|
timestamp := strings.TrimSpace(headers["x-timestamp"])
|
||||||
|
signature := strings.ToLower(strings.TrimSpace(headers["x-signature"]))
|
||||||
|
if timestamp == "" || signature == "" {
|
||||||
|
return fmt.Errorf("airwallex notification missing x-timestamp or x-signature header")
|
||||||
|
}
|
||||||
|
|
||||||
|
mac := hmac.New(sha256.New, []byte(secret))
|
||||||
|
_, _ = mac.Write([]byte(timestamp))
|
||||||
|
_, _ = mac.Write([]byte(rawBody))
|
||||||
|
expected := hex.EncodeToString(mac.Sum(nil))
|
||||||
|
if !hmac.Equal([]byte(expected), []byte(signature)) {
|
||||||
|
return fmt.Errorf("airwallex invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := parseAirwallexWebhookTimestamp(timestamp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if now.IsZero() {
|
||||||
|
now = time.Now()
|
||||||
|
}
|
||||||
|
if diff := now.Sub(ts).Abs(); diff > airwallexWebhookTolerance {
|
||||||
|
return fmt.Errorf("airwallex webhook timestamp outside tolerance")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAirwallexWebhookTimestamp(raw string) (time.Time, error) {
|
||||||
|
ts, err := decimal.NewFromString(strings.TrimSpace(raw))
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("airwallex invalid webhook timestamp")
|
||||||
|
}
|
||||||
|
millis := ts.IntPart()
|
||||||
|
if millis <= 0 {
|
||||||
|
return time.Time{}, fmt.Errorf("airwallex invalid webhook timestamp")
|
||||||
|
}
|
||||||
|
return time.UnixMilli(millis), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAirwallexTime(raw string) (time.Time, error) {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return time.Time{}, fmt.Errorf("empty time")
|
||||||
|
}
|
||||||
|
for _, layout := range []string{time.RFC3339, "2006-01-02T15:04:05-0700", "2006-01-02T15:04:05.000-0700"} {
|
||||||
|
if t, err := time.Parse(layout, raw); err == nil {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return time.Time{}, fmt.Errorf("invalid time: %s", raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeAirwallexResponse(body []byte) string {
|
||||||
|
summary := strings.Join(strings.Fields(string(body)), " ")
|
||||||
|
if summary == "" {
|
||||||
|
return "<empty>"
|
||||||
|
}
|
||||||
|
if len(summary) > airwallexMaxErrorSummary {
|
||||||
|
return summary[:airwallexMaxErrorSummary] + "..."
|
||||||
|
}
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexAuthResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt string `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexCreatePaymentIntentRequest struct {
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Amount airwallexRequestAmount `json:"amount"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
MerchantOrderID string `json:"merchant_order_id"`
|
||||||
|
ReturnURL string `json:"return_url,omitempty"`
|
||||||
|
Descriptor string `json:"descriptor,omitempty"`
|
||||||
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexCreateRefundRequest struct {
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
PaymentIntentID string `json:"payment_intent_id"`
|
||||||
|
Amount airwallexRequestAmount `json:"amount,omitempty"`
|
||||||
|
Reason string `json:"reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexRequestAmount struct {
|
||||||
|
decimal.Decimal
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAirwallexRequestAmount(amount decimal.Decimal) airwallexRequestAmount {
|
||||||
|
return airwallexRequestAmount{Decimal: amount}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a airwallexRequestAmount) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(a.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *airwallexRequestAmount) UnmarshalJSON(data []byte) error {
|
||||||
|
amount, err := decimal.NewFromString(strings.Trim(string(data), `"`))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.Decimal = amount
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexPaymentIntent struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
MerchantOrderID string `json:"merchant_order_id"`
|
||||||
|
Amount decimal.Decimal `json:"amount"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexRefund struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
PaymentIntentID string `json:"payment_intent_id"`
|
||||||
|
Amount decimal.Decimal `json:"amount"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type airwallexWebhookEvent struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
AccountID string `json:"accountId"`
|
||||||
|
AccountIDSnake string `json:"account_id"`
|
||||||
|
Data struct {
|
||||||
|
Object json.RawMessage `json:"object"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e airwallexWebhookEvent) accountID() string {
|
||||||
|
if accountID := strings.TrimSpace(e.AccountID); accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(e.AccountIDSnake)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ payment.Provider = (*Airwallex)(nil)
|
||||||
|
_ payment.CancelableProvider = (*Airwallex)(nil)
|
||||||
|
_ payment.MerchantIdentityProvider = (*Airwallex)(nil)
|
||||||
|
)
|
||||||
352
backend/internal/payment/provider/airwallex_test.go
Normal file
352
backend/internal/payment/provider/airwallex_test.go
Normal file
@ -0,0 +1,352 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewAirwallexValidatesConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "secret",
|
||||||
|
"apiBase": "https://evil.example.com/api/v1",
|
||||||
|
})
|
||||||
|
require.ErrorContains(t, err, "apiBase host")
|
||||||
|
|
||||||
|
_, err = NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "secret",
|
||||||
|
"apiBase": airwallexDemoAPIBase,
|
||||||
|
"countryCode": "C1",
|
||||||
|
})
|
||||||
|
require.ErrorContains(t, err, "countryCode")
|
||||||
|
|
||||||
|
prov, err := NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "secret",
|
||||||
|
"apiBase": airwallexDemoAPIBase,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, payment.TypeAirwallex, prov.ProviderKey())
|
||||||
|
require.Equal(t, []payment.PaymentType{payment.TypeAirwallex}, prov.SupportedTypes())
|
||||||
|
require.Equal(t, payment.DefaultPaymentCurrency, prov.config["currency"])
|
||||||
|
require.Equal(t, airwallexDefaultCountry, prov.config["countryCode"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexCreatePaymentUsesServerAmountAndStableRequestID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var createRequests []airwallexCreatePaymentIntentRequest
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/authentication/login":
|
||||||
|
require.Equal(t, "cid", r.Header.Get("x-client-id"))
|
||||||
|
require.Equal(t, "key", r.Header.Get("x-api-key"))
|
||||||
|
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
|
||||||
|
case "/api/v1/pa/payment_intents/create":
|
||||||
|
require.Equal(t, "Bearer token-1", r.Header.Get("Authorization"))
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, string(body), `"amount":12.34`)
|
||||||
|
var payload airwallexCreatePaymentIntentRequest
|
||||||
|
require.NoError(t, json.Unmarshal(body, &payload))
|
||||||
|
createRequests = append(createRequests, payload)
|
||||||
|
_, _ = w.Write([]byte(`{"id":"int_123","client_secret":"secret_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"REQUIRES_PAYMENT_METHOD"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
prov := mustTestAirwallexProvider(t, server)
|
||||||
|
resp, err := prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
|
||||||
|
OrderID: "sub2_order",
|
||||||
|
Amount: "12.34",
|
||||||
|
ReturnURL: "https://merchant.example.com/payment/result",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "int_123", resp.TradeNo)
|
||||||
|
require.Equal(t, "secret_123", resp.ClientSecret)
|
||||||
|
require.Equal(t, "int_123", resp.IntentID)
|
||||||
|
require.Equal(t, "CNY", resp.Currency)
|
||||||
|
require.Equal(t, "CN", resp.CountryCode)
|
||||||
|
require.Equal(t, "demo", resp.PaymentEnv)
|
||||||
|
require.Len(t, createRequests, 1)
|
||||||
|
require.Equal(t, "12.34", createRequests[0].Amount.StringFixed(2))
|
||||||
|
require.Equal(t, "CNY", createRequests[0].Currency)
|
||||||
|
require.Equal(t, "sub2_order", createRequests[0].MerchantOrderID)
|
||||||
|
require.Equal(t, airwallexDeterministicRequestID("payment-intent", "sub2_order", "12.34", "CNY"), createRequests[0].RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexCreatePaymentUsesConfiguredCurrency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var createRequest airwallexCreatePaymentIntentRequest
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/authentication/login":
|
||||||
|
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
|
||||||
|
case "/api/v1/pa/payment_intents/create":
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, json.Unmarshal(body, &createRequest))
|
||||||
|
_, _ = w.Write([]byte(`{"id":"int_123","client_secret":"secret_123","amount":12.34,"currency":"HKD","merchant_order_id":"sub2_order","status":"REQUIRES_PAYMENT_METHOD"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
prov, err := NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "whsec",
|
||||||
|
"apiBase": airwallexDemoAPIBase,
|
||||||
|
"currency": "hkd",
|
||||||
|
"countryCode": "HK",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
prov.config["apiBase"] = server.URL + "/api/v1"
|
||||||
|
prov.httpClient = server.Client()
|
||||||
|
|
||||||
|
resp, err := prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
|
||||||
|
OrderID: "sub2_order",
|
||||||
|
Amount: "12.34",
|
||||||
|
ReturnURL: "https://merchant.example.com/payment/result",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "HKD", createRequest.Currency)
|
||||||
|
require.Equal(t, "HKD", resp.Currency)
|
||||||
|
require.Equal(t, "HK", resp.CountryCode)
|
||||||
|
require.Equal(t, "HKD", prov.MerchantIdentityMetadata()["currency"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexRequestsUseConfiguredAccountID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
paRequestCount := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/authentication/login":
|
||||||
|
require.Equal(t, "acct_123", r.Header.Get("x-login-as"))
|
||||||
|
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
|
||||||
|
case "/api/v1/pa/payment_intents/create":
|
||||||
|
paRequestCount++
|
||||||
|
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
|
||||||
|
_, _ = w.Write([]byte(`{"id":"int_123","client_secret":"secret_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"REQUIRES_PAYMENT_METHOD"}`))
|
||||||
|
case "/api/v1/pa/payment_intents/int_123":
|
||||||
|
paRequestCount++
|
||||||
|
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
|
||||||
|
_, _ = w.Write([]byte(`{"id":"int_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"SUCCEEDED"}`))
|
||||||
|
case "/api/v1/pa/refunds/create":
|
||||||
|
paRequestCount++
|
||||||
|
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
|
||||||
|
_, _ = w.Write([]byte(`{"id":"ref_123","payment_intent_id":"int_123","amount":12.34,"currency":"CNY","status":"SETTLED"}`))
|
||||||
|
case "/api/v1/pa/payment_intents/int_123/cancel":
|
||||||
|
paRequestCount++
|
||||||
|
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
|
||||||
|
_, _ = w.Write([]byte(`{"id":"int_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"CANCELLED"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
prov, err := NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "whsec",
|
||||||
|
"apiBase": airwallexDemoAPIBase,
|
||||||
|
"accountId": "acct_123",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
prov.config["apiBase"] = server.URL + "/api/v1"
|
||||||
|
prov.httpClient = server.Client()
|
||||||
|
|
||||||
|
_, err = prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
|
||||||
|
OrderID: "sub2_order",
|
||||||
|
Amount: "12.34",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = prov.QueryOrder(context.Background(), "int_123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = prov.Refund(context.Background(), payment.RefundRequest{
|
||||||
|
TradeNo: "int_123",
|
||||||
|
Amount: "12.34",
|
||||||
|
Reason: "test refund",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, prov.CancelPayment(context.Background(), "int_123"))
|
||||||
|
require.Contains(t, prov.tokenCacheKey(), "acct_123")
|
||||||
|
require.Equal(t, 4, paRequestCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexRefundRejectsUnsettledStatus(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, status := range []string{"RECEIVED", "ACCEPTED", "FAILED"} {
|
||||||
|
t.Run(status, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/authentication/login":
|
||||||
|
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
|
||||||
|
case "/api/v1/pa/refunds/create":
|
||||||
|
_, _ = w.Write([]byte(`{"id":"ref_123","payment_intent_id":"int_123","amount":12.34,"currency":"CNY","status":"` + status + `"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
prov := mustTestAirwallexProvider(t, server)
|
||||||
|
resp, err := prov.Refund(context.Background(), payment.RefundRequest{
|
||||||
|
TradeNo: "int_123",
|
||||||
|
Amount: "12.34",
|
||||||
|
Reason: "test refund",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.ErrorContains(t, err, "airwallex refund not settled")
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, "ref_123", resp.RefundID)
|
||||||
|
if status == airwallexRefundStatusFailed {
|
||||||
|
require.Equal(t, payment.ProviderStatusFailed, resp.Status)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, payment.ProviderStatusPending, resp.Status)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexAuthErrorIncludesCredentialGuidance(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "/api/v1/authentication/login", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"code":"credentials_invalid","details":["Access Denied"],"message":"UNAUTHORIZED","source":""}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
prov := mustTestAirwallexProvider(t, server)
|
||||||
|
_, err := prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
|
||||||
|
OrderID: "sub2_order",
|
||||||
|
Amount: "12.34",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.ErrorContains(t, err, "credentials_invalid")
|
||||||
|
require.ErrorContains(t, err, "API Base environment")
|
||||||
|
require.ErrorContains(t, err, "https://api-demo.airwallex.com/api/v1")
|
||||||
|
require.ErrorContains(t, err, "https://api.airwallex.com/api/v1")
|
||||||
|
require.ErrorContains(t, err, "Account ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexVerifyNotificationRequiresValidSignatureAndCurrency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
prov, err := NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "whsec",
|
||||||
|
"apiBase": airwallexDemoAPIBase,
|
||||||
|
"accountId": "acct_123",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
raw := `{"id":"evt_1","name":"payment_intent.succeeded","accountId":"acct_123","data":{"object":{"id":"int_123","merchant_order_id":"sub2_abc","amount":88.66,"currency":"CNY","status":"SUCCEEDED"}}}`
|
||||||
|
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||||
|
headers := signedAirwallexHeaders(raw, timestamp, "whsec")
|
||||||
|
|
||||||
|
n, err := prov.VerifyNotification(context.Background(), raw, headers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, n)
|
||||||
|
require.Equal(t, "int_123", n.TradeNo)
|
||||||
|
require.Equal(t, "sub2_abc", n.OrderID)
|
||||||
|
require.Equal(t, payment.NotificationStatusSuccess, n.Status)
|
||||||
|
require.InDelta(t, 88.66, n.Amount, 0.0001)
|
||||||
|
require.Equal(t, "CNY", n.Metadata["currency"])
|
||||||
|
require.Equal(t, "acct_123", n.Metadata["account_id"])
|
||||||
|
|
||||||
|
headers["x-signature"] = strings.Repeat("0", 64)
|
||||||
|
_, err = prov.VerifyNotification(context.Background(), raw, headers)
|
||||||
|
require.ErrorContains(t, err, "invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAirwallexWebhookSignatureRejectsReplay(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
raw := `{"id":"evt_1"}`
|
||||||
|
timestamp := "1778241600000"
|
||||||
|
headers := signedAirwallexHeaders(raw, timestamp, "whsec")
|
||||||
|
err := verifyAirwallexWebhookSignature(raw, headers, "whsec", time.UnixMilli(1778241600000).Add(airwallexWebhookTolerance+time.Millisecond))
|
||||||
|
require.ErrorContains(t, err, "outside tolerance")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAirwallexQueryOrderMapsSucceeded(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/authentication/login":
|
||||||
|
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
|
||||||
|
case "/api/v1/pa/payment_intents/int_123":
|
||||||
|
_, _ = w.Write([]byte(`{"id":"int_123","amount":99.01,"currency":"CNY","merchant_order_id":"sub2_order","status":"SUCCEEDED"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
prov := mustTestAirwallexProvider(t, server)
|
||||||
|
resp, err := prov.QueryOrder(context.Background(), "int_123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, payment.ProviderStatusPaid, resp.Status)
|
||||||
|
require.InDelta(t, 99.01, resp.Amount, 0.0001)
|
||||||
|
require.Equal(t, "CNY", resp.Metadata["currency"])
|
||||||
|
require.Equal(t, "SUCCEEDED", resp.Metadata["status"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustTestAirwallexProvider(t *testing.T, server *httptest.Server) *Airwallex {
|
||||||
|
t.Helper()
|
||||||
|
prov, err := NewAirwallex("1", map[string]string{
|
||||||
|
"clientId": "cid",
|
||||||
|
"apiKey": "key",
|
||||||
|
"webhookSecret": "whsec",
|
||||||
|
"apiBase": airwallexDemoAPIBase,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
prov.config["apiBase"] = server.URL + "/api/v1"
|
||||||
|
prov.httpClient = server.Client()
|
||||||
|
return prov
|
||||||
|
}
|
||||||
|
|
||||||
|
func signedAirwallexHeaders(rawBody, timestamp, secret string) map[string]string {
|
||||||
|
mac := hmac.New(sha256.New, []byte(secret))
|
||||||
|
_, _ = mac.Write([]byte(timestamp))
|
||||||
|
_, _ = mac.Write([]byte(rawBody))
|
||||||
|
return map[string]string{
|
||||||
|
"x-timestamp": timestamp,
|
||||||
|
"x-signature": hex.EncodeToString(mac.Sum(nil)),
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,6 +17,8 @@ func CreateProvider(providerKey string, instanceID string, config map[string]str
|
|||||||
return NewWxpay(instanceID, config)
|
return NewWxpay(instanceID, config)
|
||||||
case payment.TypeStripe:
|
case payment.TypeStripe:
|
||||||
return NewStripe(instanceID, config)
|
return NewStripe(instanceID, config)
|
||||||
|
case payment.TypeAirwallex:
|
||||||
|
return NewAirwallex(instanceID, config)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown provider key: %s", providerKey)
|
return nil, fmt.Errorf("unknown provider key: %s", providerKey)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
// Stripe constants.
|
// Stripe constants.
|
||||||
const (
|
const (
|
||||||
stripeCurrency = "cny"
|
|
||||||
stripeEventPaymentSuccess = "payment_intent.succeeded"
|
stripeEventPaymentSuccess = "payment_intent.succeeded"
|
||||||
stripeEventPaymentFailed = "payment_intent.payment_failed"
|
stripeEventPaymentFailed = "payment_intent.payment_failed"
|
||||||
)
|
)
|
||||||
@ -34,9 +33,15 @@ func NewStripe(instanceID string, config map[string]string) (*Stripe, error) {
|
|||||||
if config["secretKey"] == "" {
|
if config["secretKey"] == "" {
|
||||||
return nil, fmt.Errorf("stripe config missing required key: secretKey")
|
return nil, fmt.Errorf("stripe config missing required key: secretKey")
|
||||||
}
|
}
|
||||||
|
cfg := cloneStringMap(config)
|
||||||
|
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("stripe config currency: %w", err)
|
||||||
|
}
|
||||||
|
cfg["currency"] = currency
|
||||||
return &Stripe{
|
return &Stripe{
|
||||||
instanceID: instanceID,
|
instanceID: instanceID,
|
||||||
config: config,
|
config: cfg,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,6 +65,24 @@ func (s *Stripe) SupportedTypes() []payment.PaymentType {
|
|||||||
return []payment.PaymentType{payment.TypeStripe}
|
return []payment.PaymentType{payment.TypeStripe}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Stripe) MerchantIdentityMetadata() map[string]string {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return map[string]string{"currency": s.currency()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stripe) currency() string {
|
||||||
|
if s == nil {
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
currency, err := payment.NormalizePaymentCurrency(s.config["currency"])
|
||||||
|
if err != nil {
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
return currency
|
||||||
|
}
|
||||||
|
|
||||||
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
|
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
|
||||||
var stripePaymentMethodTypes = map[string][]string{
|
var stripePaymentMethodTypes = map[string][]string{
|
||||||
payment.TypeCard: {"card"},
|
payment.TypeCard: {"card"},
|
||||||
@ -72,7 +95,8 @@ var stripePaymentMethodTypes = map[string][]string{
|
|||||||
func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||||
s.ensureInit()
|
s.ensureInit()
|
||||||
|
|
||||||
amountInCents, err := payment.YuanToFen(req.Amount)
|
currency := s.currency()
|
||||||
|
amountInMinorUnit, err := payment.AmountToMinorUnit(req.Amount, currency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("stripe create payment: %w", err)
|
return nil, fmt.Errorf("stripe create payment: %w", err)
|
||||||
}
|
}
|
||||||
@ -86,8 +110,8 @@ func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := &stripe.PaymentIntentCreateParams{
|
params := &stripe.PaymentIntentCreateParams{
|
||||||
Amount: stripe.Int64(amountInCents),
|
Amount: stripe.Int64(amountInMinorUnit),
|
||||||
Currency: stripe.String(stripeCurrency),
|
Currency: stripe.String(strings.ToLower(currency)),
|
||||||
PaymentMethodTypes: pmTypes,
|
PaymentMethodTypes: pmTypes,
|
||||||
Description: stripe.String(req.Subject),
|
Description: stripe.String(req.Subject),
|
||||||
Metadata: map[string]string{"orderId": req.OrderID},
|
Metadata: map[string]string{"orderId": req.OrderID},
|
||||||
@ -113,6 +137,7 @@ func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentReq
|
|||||||
return &payment.CreatePaymentResponse{
|
return &payment.CreatePaymentResponse{
|
||||||
TradeNo: pi.ID,
|
TradeNo: pi.ID,
|
||||||
ClientSecret: pi.ClientSecret,
|
ClientSecret: pi.ClientSecret,
|
||||||
|
Currency: currency,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,10 +158,14 @@ func (s *Stripe) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
|
|||||||
status = payment.ProviderStatusFailed
|
status = payment.ProviderStatusFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
currency := stripeIntentCurrency(pi.Currency, s.currency())
|
||||||
return &payment.QueryOrderResponse{
|
return &payment.QueryOrderResponse{
|
||||||
TradeNo: pi.ID,
|
TradeNo: pi.ID,
|
||||||
Status: status,
|
Status: status,
|
||||||
Amount: payment.FenToYuan(pi.Amount),
|
Amount: payment.MinorUnitToAmount(pi.Amount, currency),
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"currency": currency,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,12 +203,16 @@ func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string
|
|||||||
if err := json.Unmarshal(event.Data.Raw, &pi); err != nil {
|
if err := json.Unmarshal(event.Data.Raw, &pi); err != nil {
|
||||||
return nil, fmt.Errorf("stripe parse payment_intent: %w", err)
|
return nil, fmt.Errorf("stripe parse payment_intent: %w", err)
|
||||||
}
|
}
|
||||||
|
currency := stripeIntentCurrency(pi.Currency, payment.DefaultPaymentCurrency)
|
||||||
return &payment.PaymentNotification{
|
return &payment.PaymentNotification{
|
||||||
TradeNo: pi.ID,
|
TradeNo: pi.ID,
|
||||||
OrderID: pi.Metadata["orderId"],
|
OrderID: pi.Metadata["orderId"],
|
||||||
Amount: payment.FenToYuan(pi.Amount),
|
Amount: payment.MinorUnitToAmount(pi.Amount, currency),
|
||||||
Status: status,
|
Status: status,
|
||||||
RawData: rawBody,
|
RawData: rawBody,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"currency": currency,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -187,14 +220,14 @@ func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string
|
|||||||
func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||||
s.ensureInit()
|
s.ensureInit()
|
||||||
|
|
||||||
amountInCents, err := payment.YuanToFen(req.Amount)
|
amountInMinorUnit, err := payment.AmountToMinorUnit(req.Amount, s.currency())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("stripe refund: %w", err)
|
return nil, fmt.Errorf("stripe refund: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
params := &stripe.RefundCreateParams{
|
params := &stripe.RefundCreateParams{
|
||||||
PaymentIntent: stripe.String(req.TradeNo),
|
PaymentIntent: stripe.String(req.TradeNo),
|
||||||
Amount: stripe.Int64(amountInCents),
|
Amount: stripe.Int64(amountInMinorUnit),
|
||||||
Reason: stripe.String(string(stripe.RefundReasonRequestedByCustomer)),
|
Reason: stripe.String(string(stripe.RefundReasonRequestedByCustomer)),
|
||||||
}
|
}
|
||||||
params.Context = ctx
|
params.Context = ctx
|
||||||
@ -215,6 +248,18 @@ func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*paymen
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stripeIntentCurrency(raw stripe.Currency, fallback string) string {
|
||||||
|
currency, err := payment.NormalizePaymentCurrency(string(raw))
|
||||||
|
if err != nil || currency == payment.DefaultPaymentCurrency && strings.TrimSpace(string(raw)) == "" {
|
||||||
|
normalizedFallback, fallbackErr := payment.NormalizePaymentCurrency(fallback)
|
||||||
|
if fallbackErr == nil {
|
||||||
|
return normalizedFallback
|
||||||
|
}
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
return currency
|
||||||
|
}
|
||||||
|
|
||||||
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
|
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
|
||||||
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
|
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
|
||||||
func resolveStripeMethodTypes(instanceSubMethods string) []string {
|
func resolveStripeMethodTypes(instanceSubMethods string) []string {
|
||||||
@ -257,6 +302,7 @@ func (s *Stripe) CancelPayment(ctx context.Context, tradeNo string) error {
|
|||||||
|
|
||||||
// Ensure interface compliance.
|
// Ensure interface compliance.
|
||||||
var (
|
var (
|
||||||
_ payment.Provider = (*Stripe)(nil)
|
_ payment.Provider = (*Stripe)(nil)
|
||||||
_ payment.CancelableProvider = (*Stripe)(nil)
|
_ payment.CancelableProvider = (*Stripe)(nil)
|
||||||
|
_ payment.MerchantIdentityProvider = (*Stripe)(nil)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ const (
|
|||||||
TypeCard PaymentType = "card"
|
TypeCard PaymentType = "card"
|
||||||
TypeLink PaymentType = "link"
|
TypeLink PaymentType = "link"
|
||||||
TypeEasyPay PaymentType = "easypay"
|
TypeEasyPay PaymentType = "easypay"
|
||||||
|
TypeAirwallex PaymentType = "airwallex"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Order status constants shared across payment and service layers.
|
// Order status constants shared across payment and service layers.
|
||||||
@ -82,6 +83,8 @@ func GetBasePaymentType(t string) string {
|
|||||||
switch {
|
switch {
|
||||||
case t == TypeEasyPay:
|
case t == TypeEasyPay:
|
||||||
return TypeEasyPay
|
return TypeEasyPay
|
||||||
|
case t == TypeAirwallex:
|
||||||
|
return TypeAirwallex
|
||||||
case t == TypeStripe || t == TypeCard || t == TypeLink:
|
case t == TypeStripe || t == TypeCard || t == TypeLink:
|
||||||
return TypeStripe
|
return TypeStripe
|
||||||
case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
|
case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
|
||||||
@ -96,7 +99,7 @@ func GetBasePaymentType(t string) string {
|
|||||||
// CreatePaymentRequest holds the parameters for creating a new payment.
|
// CreatePaymentRequest holds the parameters for creating a new payment.
|
||||||
type CreatePaymentRequest struct {
|
type CreatePaymentRequest struct {
|
||||||
OrderID string // Internal order ID
|
OrderID string // Internal order ID
|
||||||
Amount string // Pay amount in CNY (formatted to 2 decimal places)
|
Amount string // 支付金额,按服务商实例配置的币种解释
|
||||||
PaymentType string // e.g. "alipay", "wxpay", "stripe"
|
PaymentType string // e.g. "alipay", "wxpay", "stripe"
|
||||||
Subject string // Product description
|
Subject string // Product description
|
||||||
NotifyURL string // Webhook callback URL
|
NotifyURL string // Webhook callback URL
|
||||||
@ -141,7 +144,11 @@ type CreatePaymentResponse struct {
|
|||||||
TradeNo string // Third-party transaction ID
|
TradeNo string // Third-party transaction ID
|
||||||
PayURL string // H5 payment URL (alipay/wxpay)
|
PayURL string // H5 payment URL (alipay/wxpay)
|
||||||
QRCode string // QR code content for scanning
|
QRCode string // QR code content for scanning
|
||||||
ClientSecret string // Stripe PaymentIntent client secret
|
ClientSecret string // Stripe PaymentIntent 客户端密钥
|
||||||
|
IntentID string // 前端 SDK 需要的服务商支付意图 ID
|
||||||
|
Currency string // 服务商支付币种
|
||||||
|
CountryCode string // 服务商收银台国家/地区代码
|
||||||
|
PaymentEnv string // 服务商前端环境标识
|
||||||
ResultType CreatePaymentResultType // Typed result contract for frontend flows
|
ResultType CreatePaymentResultType // Typed result contract for frontend flows
|
||||||
OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
|
OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
|
||||||
JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
|
JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
|
||||||
@ -151,7 +158,7 @@ type CreatePaymentResponse struct {
|
|||||||
type QueryOrderResponse struct {
|
type QueryOrderResponse struct {
|
||||||
TradeNo string
|
TradeNo string
|
||||||
Status string // "pending", "paid", "failed", "refunded"
|
Status string // "pending", "paid", "failed", "refunded"
|
||||||
Amount float64 // Amount in CNY
|
Amount float64 // 按服务商返回币种解释的金额
|
||||||
PaidAt string // RFC3339 timestamp or empty
|
PaidAt string // RFC3339 timestamp or empty
|
||||||
Metadata map[string]string
|
Metadata map[string]string
|
||||||
}
|
}
|
||||||
|
|||||||
@ -46,7 +46,7 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
|
|||||||
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
|
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("User-Agent", GetUserAgent())
|
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
@ -440,7 +440,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
|
|||||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||||
reqBody := LoadCodeAssistRequest{}
|
reqBody := LoadCodeAssistRequest{}
|
||||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||||
reqBody.Metadata.IDEVersion = "1.20.6"
|
reqBody.Metadata.IDEVersion = GetUserAgentVersionForContext(ctx)
|
||||||
reqBody.Metadata.IDEName = "antigravity"
|
reqBody.Metadata.IDEName = "antigravity"
|
||||||
|
|
||||||
bodyBytes, err := json.Marshal(reqBody)
|
bodyBytes, err := json.Marshal(reqBody)
|
||||||
@ -461,7 +461,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", GetUserAgent())
|
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -540,7 +540,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s
|
|||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", GetUserAgent())
|
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -674,7 +674,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", GetUserAgent())
|
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -792,7 +792,7 @@ func (c *Client) SetUserSettings(ctx context.Context, accessToken string) (*SetU
|
|||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Accept", "*/*")
|
req.Header.Set("Accept", "*/*")
|
||||||
req.Header.Set("User-Agent", GetUserAgent())
|
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.21.1")
|
req.Header.Set("X-Goog-Api-Client", "gl-node/22.21.1")
|
||||||
req.Host = "daily-cloudcode-pa.googleapis.com"
|
req.Host = "daily-cloudcode-pa.googleapis.com"
|
||||||
|
|
||||||
@ -835,7 +835,7 @@ func (c *Client) FetchUserInfo(ctx context.Context, accessToken, projectID strin
|
|||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Accept", "*/*")
|
req.Header.Set("Accept", "*/*")
|
||||||
req.Header.Set("User-Agent", GetUserAgent())
|
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.21.1")
|
req.Header.Set("X-Goog-Api-Client", "gl-node/22.21.1")
|
||||||
req.Host = "daily-cloudcode-pa.googleapis.com"
|
req.Host = "daily-cloudcode-pa.googleapis.com"
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package antigravity
|
package antigravity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@ -9,6 +10,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -28,6 +30,12 @@ const (
|
|||||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
|
// AntigravityUserAgentVersionEnv 是 Antigravity User-Agent 版本号的环境变量名。
|
||||||
|
AntigravityUserAgentVersionEnv = "ANTIGRAVITY_USER_AGENT_VERSION"
|
||||||
|
|
||||||
|
// DefaultUserAgentVersion 是未通过环境变量或后台设置覆盖时使用的默认版本号。
|
||||||
|
DefaultUserAgentVersion = "1.23.2"
|
||||||
|
|
||||||
// 固定的 redirect_uri(用户需手动复制 code)
|
// 固定的 redirect_uri(用户需手动复制 code)
|
||||||
RedirectURI = "http://localhost:8085/callback"
|
RedirectURI = "http://localhost:8085/callback"
|
||||||
|
|
||||||
@ -49,15 +57,24 @@ const (
|
|||||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
var userAgentVersionPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
|
||||||
var defaultUserAgentVersion = "1.21.9"
|
|
||||||
|
// UserAgentVersionResolver 提供运行时 User-Agent 版本号覆盖能力。
|
||||||
|
type UserAgentVersionResolver func(ctx context.Context) string
|
||||||
|
|
||||||
|
var (
|
||||||
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置。
|
||||||
|
defaultUserAgentVersion = DefaultUserAgentVersion
|
||||||
|
userAgentVersionMu sync.RWMutex
|
||||||
|
userAgentVersionResolver UserAgentVersionResolver
|
||||||
|
)
|
||||||
|
|
||||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
var defaultClientSecret = "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"
|
var defaultClientSecret = "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"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// 从环境变量读取版本号,未设置则使用默认值
|
// 从环境变量读取版本号,未设置则使用默认值
|
||||||
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
|
if version := NormalizeUserAgentVersion(os.Getenv(AntigravityUserAgentVersionEnv)); version != "" {
|
||||||
defaultUserAgentVersion = version
|
defaultUserAgentVersion = version
|
||||||
}
|
}
|
||||||
// 从环境变量读取 client_secret,未设置则使用默认值
|
// 从环境变量读取 client_secret,未设置则使用默认值
|
||||||
@ -66,11 +83,61 @@ func init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserAgent 返回当前配置的 User-Agent
|
// NormalizeUserAgentVersion 校验并归一化 Antigravity User-Agent 版本号。
|
||||||
func GetUserAgent() string {
|
func NormalizeUserAgentVersion(version string) string {
|
||||||
|
version = strings.TrimSpace(version)
|
||||||
|
if version == "" || !userAgentVersionPattern.MatchString(version) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultUserAgentVersion 返回配置文件/环境变量层面的默认版本号。
|
||||||
|
func GetDefaultUserAgentVersion() string {
|
||||||
|
return defaultUserAgentVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserAgentVersionResolver 设置运行时版本号解析器,通常由后台 settings 注入。
|
||||||
|
func SetUserAgentVersionResolver(resolver UserAgentVersionResolver) {
|
||||||
|
userAgentVersionMu.Lock()
|
||||||
|
defer userAgentVersionMu.Unlock()
|
||||||
|
userAgentVersionResolver = resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserAgentVersionForContext 返回当前请求应使用的 Antigravity 版本号。
|
||||||
|
func GetUserAgentVersionForContext(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
userAgentVersionMu.RLock()
|
||||||
|
resolver := userAgentVersionResolver
|
||||||
|
userAgentVersionMu.RUnlock()
|
||||||
|
if resolver != nil {
|
||||||
|
if version := NormalizeUserAgentVersion(resolver(ctx)); version != "" {
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultUserAgentVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildUserAgent 使用指定版本号构造 User-Agent;版本为空或非法时回退默认值。
|
||||||
|
func BuildUserAgent(version string) string {
|
||||||
|
if normalized := NormalizeUserAgentVersion(version); normalized != "" {
|
||||||
|
return fmt.Sprintf("antigravity/%s windows/amd64", normalized)
|
||||||
|
}
|
||||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserAgentForContext 返回当前请求应使用的 User-Agent。
|
||||||
|
func GetUserAgentForContext(ctx context.Context) string {
|
||||||
|
return BuildUserAgent(GetUserAgentVersionForContext(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserAgent 返回当前配置的 User-Agent。
|
||||||
|
func GetUserAgent() string {
|
||||||
|
return GetUserAgentForContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
func getClientSecret() (string, error) {
|
func getClientSecret() (string, error) {
|
||||||
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||||
return v, nil
|
return v, nil
|
||||||
|
|||||||
@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||||
}
|
}
|
||||||
if GetUserAgent() != "antigravity/1.21.9 windows/amd64" {
|
if GetUserAgent() != "antigravity/1.23.2 windows/amd64" {
|
||||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||||
}
|
}
|
||||||
if SessionTTL != 30*time.Minute {
|
if SessionTTL != 30*time.Minute {
|
||||||
|
|||||||
@ -15,8 +15,8 @@ import (
|
|||||||
utls "github.com/refraction-networking/utls"
|
utls "github.com/refraction-networking/utls"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CapturedFingerprint mirrors the Fingerprint struct from tls-fingerprint-web.
|
// CapturedFingerprint 对应 tls-fingerprint-web 返回的 Fingerprint 结构。
|
||||||
// Used to deserialize the JSON response from the capture server.
|
// 用于反序列化 capture server 的 JSON 响应。
|
||||||
type CapturedFingerprint struct {
|
type CapturedFingerprint struct {
|
||||||
JA3Raw string `json:"ja3_raw"`
|
JA3Raw string `json:"ja3_raw"`
|
||||||
JA3Hash string `json:"ja3_hash"`
|
JA3Hash string `json:"ja3_hash"`
|
||||||
@ -35,17 +35,17 @@ type CapturedFingerprint struct {
|
|||||||
EnableGREASE bool `json:"enable_grease"`
|
EnableGREASE bool `json:"enable_grease"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDialerAgainstCaptureServer connects to the tls-fingerprint-web capture server
|
// TestDialerAgainstCaptureServer 连接 tls-fingerprint-web capture server,
|
||||||
// and verifies that the dialer's TLS fingerprint matches the configured Profile.
|
// 验证 Dialer 的 TLS 指纹是否匹配配置的 Profile。
|
||||||
//
|
//
|
||||||
// Default capture server: https://tls.sub2api.org:8090
|
// 该测试依赖外部服务,默认跳过。需要手动验证时设置:
|
||||||
// Override with env: TLSFINGERPRINT_CAPTURE_URL=https://localhost:8443
|
// TLSFINGERPRINT_CAPTURE_URL=https://localhost:8443
|
||||||
//
|
//
|
||||||
// Run: go test -v -run TestDialerAgainstCaptureServer ./internal/pkg/tlsfingerprint/...
|
// 运行方式:go test -tags=integration -v -run TestDialerAgainstCaptureServer ./internal/pkg/tlsfingerprint/...
|
||||||
func TestDialerAgainstCaptureServer(t *testing.T) {
|
func TestDialerAgainstCaptureServer(t *testing.T) {
|
||||||
captureURL := os.Getenv("TLSFINGERPRINT_CAPTURE_URL")
|
captureURL := strings.TrimSpace(os.Getenv("TLSFINGERPRINT_CAPTURE_URL"))
|
||||||
if captureURL == "" {
|
if captureURL == "" {
|
||||||
captureURL = "https://tls.sub2api.org:8090"
|
t.Skip("跳过外部 TLS 指纹 capture 测试:未设置 TLSFINGERPRINT_CAPTURE_URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@ -57,7 +57,7 @@ func TestDialerAgainstCaptureServer(t *testing.T) {
|
|||||||
profile: &Profile{
|
profile: &Profile{
|
||||||
Name: "default",
|
Name: "default",
|
||||||
EnableGREASE: false,
|
EnableGREASE: false,
|
||||||
// All empty → uses built-in defaults
|
// 全部留空时使用内置默认值
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -104,7 +104,7 @@ func TestDialerAgainstCaptureServer(t *testing.T) {
|
|||||||
t.Logf("JA3 Hash: %s", captured.JA3Hash)
|
t.Logf("JA3 Hash: %s", captured.JA3Hash)
|
||||||
t.Logf("JA4: %s", captured.JA4)
|
t.Logf("JA4: %s", captured.JA4)
|
||||||
|
|
||||||
// Resolve effective profile values (what the dialer actually uses)
|
// 解析实际生效的 Profile 值,也就是 Dialer 最终使用的值。
|
||||||
effectiveCipherSuites := tc.profile.CipherSuites
|
effectiveCipherSuites := tc.profile.CipherSuites
|
||||||
if len(effectiveCipherSuites) == 0 {
|
if len(effectiveCipherSuites) == 0 {
|
||||||
effectiveCipherSuites = defaultCipherSuites
|
effectiveCipherSuites = defaultCipherSuites
|
||||||
@ -144,7 +144,7 @@ func TestDialerAgainstCaptureServer(t *testing.T) {
|
|||||||
effectivePSKModes = []uint16{1} // psk_dhe_ke
|
effectivePSKModes = []uint16{1} // psk_dhe_ke
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify each field
|
// 校验每个指纹字段
|
||||||
assertIntSliceEqual(t, "cipher_suites", uint16sToInts(effectiveCipherSuites), captured.CipherSuites)
|
assertIntSliceEqual(t, "cipher_suites", uint16sToInts(effectiveCipherSuites), captured.CipherSuites)
|
||||||
assertIntSliceEqual(t, "curves", uint16sToInts(effectiveCurves), captured.Curves)
|
assertIntSliceEqual(t, "curves", uint16sToInts(effectiveCurves), captured.Curves)
|
||||||
assertIntSliceEqual(t, "point_formats", uint16sToInts(effectivePointFormats), captured.PointFormats)
|
assertIntSliceEqual(t, "point_formats", uint16sToInts(effectivePointFormats), captured.PointFormats)
|
||||||
@ -160,13 +160,13 @@ func TestDialerAgainstCaptureServer(t *testing.T) {
|
|||||||
t.Logf(" enable_grease: %v OK", captured.EnableGREASE)
|
t.Logf(" enable_grease: %v OK", captured.EnableGREASE)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify extension order
|
// 校验扩展顺序;如果 Profile 显式配置了 Extensions 就使用配置值,
|
||||||
// Use profile.Extensions if set, otherwise the default order (Node.js 24.x)
|
// 否则使用默认顺序(Node.js 24.x)。
|
||||||
expectedExtOrder := uint16sToInts(defaultExtensionOrder)
|
expectedExtOrder := uint16sToInts(defaultExtensionOrder)
|
||||||
if len(tc.profile.Extensions) > 0 {
|
if len(tc.profile.Extensions) > 0 {
|
||||||
expectedExtOrder = uint16sToInts(tc.profile.Extensions)
|
expectedExtOrder = uint16sToInts(tc.profile.Extensions)
|
||||||
}
|
}
|
||||||
// Strip GREASE values from both expected and captured for comparison
|
// 比较前从期望值和采集值中剔除 GREASE。
|
||||||
var filteredExpected, filteredActual []int
|
var filteredExpected, filteredActual []int
|
||||||
for _, e := range expectedExtOrder {
|
for _, e := range expectedExtOrder {
|
||||||
if !isGREASEValue(uint16(e)) {
|
if !isGREASEValue(uint16(e)) {
|
||||||
@ -180,7 +180,7 @@ func TestDialerAgainstCaptureServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assertIntSliceEqual(t, "extensions (order, non-GREASE)", filteredExpected, filteredActual)
|
assertIntSliceEqual(t, "extensions (order, non-GREASE)", filteredExpected, filteredActual)
|
||||||
|
|
||||||
// Print full captured data as JSON for debugging
|
// 打印完整采集结果,便于排查指纹差异。
|
||||||
capturedJSON, _ := json.MarshalIndent(captured, " ", " ")
|
capturedJSON, _ := json.MarshalIndent(captured, " ", " ")
|
||||||
t.Logf("Full captured fingerprint:\n %s", string(capturedJSON))
|
t.Logf("Full captured fingerprint:\n %s", string(capturedJSON))
|
||||||
})
|
})
|
||||||
|
|||||||
91
backend/internal/repository/http_upstream_antigravity.go
Normal file
91
backend/internal/repository/http_upstream_antigravity.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
// ==============================================================
|
||||||
|
// antigravity — Go 原生 TLS 指纹扩展
|
||||||
|
//
|
||||||
|
// 此文件包含 Antigravity fork 新增的 TLS 指纹代理功能,
|
||||||
|
// 与 upstream 代码完全隔离,便于 upstream 更新时的合并维护。
|
||||||
|
//
|
||||||
|
// 上游文件 http_upstream.go 中的钩子调用点:
|
||||||
|
// Do() — 匹配主机时路由到 doWithTLSFingerprint
|
||||||
|
// DoWithTLS() — profile==nil 时回退到 Do(),触发同样的路由
|
||||||
|
//
|
||||||
|
// 替代原先的 Node.js TLS 代理(node-tls-proxy),
|
||||||
|
// 直接使用 Go utls 库模拟 Claude CLI 的 TLS 指纹。
|
||||||
|
// ==============================================================
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||||
|
)
|
||||||
|
|
||||||
|
// isTLSFingerprintRoutingEnabled 检查 TLS 指纹路由是否启用
|
||||||
|
// 使用 TLSFingerprint.Enabled 配置项(而不是旧的 NodeTLSProxy.Enabled)
|
||||||
|
func (s *httpUpstreamService) isTLSFingerprintRoutingEnabled() bool {
|
||||||
|
if s.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.cfg.Gateway.TLSFingerprint.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldRouteWithTLSFingerprint 判断请求是否应该使用 TLS 指纹
|
||||||
|
// 拦截目标主机在 proxy_hosts 白名单中的 HTTPS 请求
|
||||||
|
// 白名单为空时默认代理 api.anthropic.com 和 Antigravity API 主机
|
||||||
|
func (s *httpUpstreamService) shouldRouteWithTLSFingerprint(req *http.Request) bool {
|
||||||
|
if req == nil || req.URL == nil || req.URL.Scheme != "https" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
reqHost := req.URL.Hostname()
|
||||||
|
if reqHost == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts := s.cfg.Gateway.NodeTLSProxy.ProxyHosts
|
||||||
|
if len(hosts) == 0 {
|
||||||
|
// 默认白名单:api.anthropic.com 和 Antigravity API 主机
|
||||||
|
defaultHosts := map[string]bool{
|
||||||
|
"api.anthropic.com": true,
|
||||||
|
"cloudcode-pa.googleapis.com": true,
|
||||||
|
"daily-cloudcode-pa.googleapis.com": true,
|
||||||
|
}
|
||||||
|
return defaultHosts[reqHost]
|
||||||
|
}
|
||||||
|
for _, h := range hosts {
|
||||||
|
if reqHost == h {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultTLSProfile 返回模拟 Claude CLI (Node.js 24.x) 的默认 TLS 指纹配置
|
||||||
|
// 所有 slice 字段留空 → dialer.go 自动使用内置的 Node.js 24.x 默认值
|
||||||
|
// ALPN 仅声明 http/1.1,与真实 CLI 行为一致(undici allowH2=false)
|
||||||
|
func defaultTLSProfile() *tlsfingerprint.Profile {
|
||||||
|
return &tlsfingerprint.Profile{
|
||||||
|
Name: "claude_cli_builtin",
|
||||||
|
EnableGREASE: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// doWithTLSFingerprint 使用 Go 原生 utls TLS 指纹发送请求
|
||||||
|
// 直接通过 DoWithTLS 路径,利用已有的 utls dialer 基础设施:
|
||||||
|
// - 直连:Dialer (TCP → utls handshake)
|
||||||
|
// - HTTP 代理:HTTPProxyDialer (CONNECT 隧道 → utls handshake)
|
||||||
|
// - SOCKS5 代理:SOCKS5ProxyDialer (SOCKS5 隧道 → utls handshake)
|
||||||
|
func (s *httpUpstreamService) doWithTLSFingerprint(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||||
|
proxyInfo := "direct"
|
||||||
|
if proxyURL != "" {
|
||||||
|
proxyInfo = logredact.RedactProxyURL(proxyURL)
|
||||||
|
}
|
||||||
|
slog.Debug("tls_fingerprint_routing",
|
||||||
|
"account_id", accountID,
|
||||||
|
"target", req.URL.Host,
|
||||||
|
"proxy", proxyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
return s.DoWithTLS(req, proxyURL, accountID, accountConcurrency, defaultTLSProfile())
|
||||||
|
}
|
||||||
@ -41,13 +41,13 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
|
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
|
||||||
// 验证未配置时使用 300 秒默认值
|
// 验证未配置时使用 600 秒默认值(LLM 排队较久,本地从 300s 提至 600s)
|
||||||
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
|
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
|
||||||
svc := s.newService()
|
svc := s.newService()
|
||||||
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
|
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
|
||||||
transport, ok := entry.client.Transport.(*http.Transport)
|
transport, ok := entry.client.Transport.(*http.Transport)
|
||||||
require.True(s.T(), ok, "expected *http.Transport")
|
require.True(s.T(), ok, "expected *http.Transport")
|
||||||
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
require.Equal(s.T(), 600*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
|
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
|
||||||
|
|||||||
@ -773,6 +773,8 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"backend_mode_enabled": false,
|
"backend_mode_enabled": false,
|
||||||
"enable_cch_signing": false,
|
"enable_cch_signing": false,
|
||||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||||
|
"rewrite_message_cache_control": false,
|
||||||
|
"antigravity_user_agent_version": "",
|
||||||
"enable_fingerprint_unification": true,
|
"enable_fingerprint_unification": true,
|
||||||
"enable_metadata_passthrough": false,
|
"enable_metadata_passthrough": false,
|
||||||
"web_search_emulation_enabled": false,
|
"web_search_emulation_enabled": false,
|
||||||
@ -988,6 +990,8 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"enable_metadata_passthrough": false,
|
"enable_metadata_passthrough": false,
|
||||||
"enable_cch_signing": false,
|
"enable_cch_signing": false,
|
||||||
"enable_anthropic_cache_ttl_1h_injection": false,
|
"enable_anthropic_cache_ttl_1h_injection": false,
|
||||||
|
"rewrite_message_cache_control": false,
|
||||||
|
"antigravity_user_agent_version": "",
|
||||||
"web_search_emulation_enabled": false,
|
"web_search_emulation_enabled": false,
|
||||||
"payment_visible_method_alipay_source": "",
|
"payment_visible_method_alipay_source": "",
|
||||||
"payment_visible_method_wxpay_source": "",
|
"payment_visible_method_wxpay_source": "",
|
||||||
|
|||||||
@ -20,8 +20,35 @@ const (
|
|||||||
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
||||||
// StripeDomain is the domain for Stripe.js SDK
|
// StripeDomain is the domain for Stripe.js SDK
|
||||||
StripeDomain = "https://*.stripe.com"
|
StripeDomain = "https://*.stripe.com"
|
||||||
|
// AirwallexStaticDomain 是 Airwallex 生产环境 SDK 脚本域名。
|
||||||
|
AirwallexStaticDomain = "https://static.airwallex.com"
|
||||||
|
// AirwallexCheckoutDomain 是 Airwallex 生产环境收银台元素和 iframe 域名。
|
||||||
|
AirwallexCheckoutDomain = "https://checkout.airwallex.com"
|
||||||
|
// AirwallexDemoStaticDomain 是 Airwallex 沙箱环境 SDK 脚本域名。
|
||||||
|
AirwallexDemoStaticDomain = "https://static-demo.airwallex.com"
|
||||||
|
// AirwallexDemoCheckoutDomain 是 Airwallex 沙箱环境收银台元素和 iframe 域名。
|
||||||
|
AirwallexDemoCheckoutDomain = "https://checkout-demo.airwallex.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var requiredCSPDirectiveValues = []struct {
|
||||||
|
directive string
|
||||||
|
value string
|
||||||
|
}{
|
||||||
|
{"script-src", CloudflareInsightsDomain},
|
||||||
|
{"script-src", StripeDomain},
|
||||||
|
{"frame-src", StripeDomain},
|
||||||
|
{"script-src", AirwallexStaticDomain},
|
||||||
|
{"script-src", AirwallexCheckoutDomain},
|
||||||
|
{"style-src", AirwallexStaticDomain},
|
||||||
|
{"style-src", AirwallexCheckoutDomain},
|
||||||
|
{"frame-src", AirwallexCheckoutDomain},
|
||||||
|
{"script-src", AirwallexDemoStaticDomain},
|
||||||
|
{"script-src", AirwallexDemoCheckoutDomain},
|
||||||
|
{"style-src", AirwallexDemoStaticDomain},
|
||||||
|
{"style-src", AirwallexDemoCheckoutDomain},
|
||||||
|
{"frame-src", AirwallexDemoCheckoutDomain},
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateNonce generates a cryptographically secure random nonce.
|
// GenerateNonce generates a cryptographically secure random nonce.
|
||||||
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
|
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
|
||||||
func GenerateNonce() (string, error) {
|
func GenerateNonce() (string, error) {
|
||||||
@ -100,29 +127,39 @@ func isAPIRoutePath(c *gin.Context) bool {
|
|||||||
strings.HasPrefix(path, "/images")
|
strings.HasPrefix(path, "/images")
|
||||||
}
|
}
|
||||||
|
|
||||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
|
// enhanceCSPPolicy 确保 CSP 策略包含 nonce 支持和支付 SDK 必需域名。
|
||||||
// and Stripe.js domains. This allows the application to work correctly even if the
|
// 这样旧配置文件没有及时补域名时,前端支付组件仍能正常加载。
|
||||||
// config file has an older CSP policy.
|
|
||||||
func enhanceCSPPolicy(policy string) string {
|
func enhanceCSPPolicy(policy string) string {
|
||||||
// Add nonce placeholder to script-src if not present
|
// Add nonce placeholder to script-src if not present
|
||||||
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
|
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
|
||||||
policy = addToDirective(policy, "script-src", NonceTemplate)
|
policy = addToDirective(policy, "script-src", NonceTemplate)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Cloudflare Insights domain to script-src if not present
|
for _, required := range requiredCSPDirectiveValues {
|
||||||
if !strings.Contains(policy, CloudflareInsightsDomain) {
|
if !directiveHasValue(policy, required.directive, required.value) {
|
||||||
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
|
policy = addToDirective(policy, required.directive, required.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Stripe.js domain to script-src and frame-src if not present
|
|
||||||
if !strings.Contains(policy, "stripe.com") {
|
|
||||||
policy = addToDirective(policy, "script-src", StripeDomain)
|
|
||||||
policy = addToDirective(policy, "frame-src", StripeDomain)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func directiveHasValue(policy, directive, value string) bool {
|
||||||
|
for _, rawDirective := range strings.Split(policy, ";") {
|
||||||
|
fields := strings.Fields(strings.TrimSpace(rawDirective))
|
||||||
|
if len(fields) == 0 || fields[0] != directive {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, field := range fields[1:] {
|
||||||
|
if field == value {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// addToDirective adds a value to a specific CSP directive.
|
// addToDirective adds a value to a specific CSP directive.
|
||||||
// If the directive doesn't exist, it will be added after default-src.
|
// If the directive doesn't exist, it will be added after default-src.
|
||||||
func addToDirective(policy, directive, value string) string {
|
func addToDirective(policy, directive, value string) string {
|
||||||
|
|||||||
@ -330,6 +330,52 @@ func TestEnhanceCSPPolicy(t *testing.T) {
|
|||||||
assert.NotContains(t, enhanced, NonceTemplate)
|
assert.NotContains(t, enhanced, NonceTemplate)
|
||||||
assert.Contains(t, enhanced, "'nonce-existing'")
|
assert.Contains(t, enhanced, "'nonce-existing'")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("adds_airwallex_domains_for_payment_sdk", func(t *testing.T) {
|
||||||
|
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__; style-src 'self'; frame-src 'self'"
|
||||||
|
enhanced := enhanceCSPPolicy(policy)
|
||||||
|
|
||||||
|
assert.Contains(t, enhanced, "script-src 'self' __CSP_NONCE__")
|
||||||
|
assert.Contains(t, enhanced, AirwallexStaticDomain)
|
||||||
|
assert.Contains(t, enhanced, AirwallexCheckoutDomain)
|
||||||
|
assert.Contains(t, enhanced, AirwallexDemoStaticDomain)
|
||||||
|
assert.Contains(t, enhanced, AirwallexDemoCheckoutDomain)
|
||||||
|
assert.Contains(t, enhanced, "style-src 'self'")
|
||||||
|
assert.Contains(t, enhanced, "frame-src 'self'")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does_not_duplicate_airwallex_domains", func(t *testing.T) {
|
||||||
|
policy := "default-src 'self'; script-src 'self' https://static.airwallex.com https://static-demo.airwallex.com; frame-src https://checkout.airwallex.com https://checkout-demo.airwallex.com"
|
||||||
|
enhanced := enhanceCSPPolicy(policy)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexStaticDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexCheckoutDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexStaticDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexCheckoutDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "frame-src", AirwallexCheckoutDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexDemoStaticDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexDemoCheckoutDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexDemoStaticDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexDemoCheckoutDomain))
|
||||||
|
assert.Equal(t, 1, countDirectiveValue(enhanced, "frame-src", AirwallexDemoCheckoutDomain))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func countDirectiveValue(policy, directive, value string) int {
|
||||||
|
for _, rawDirective := range strings.Split(policy, ";") {
|
||||||
|
fields := strings.Fields(strings.TrimSpace(rawDirective))
|
||||||
|
if len(fields) == 0 || fields[0] != directive {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
count := 0
|
||||||
|
for _, field := range fields[1:] {
|
||||||
|
if field == value {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddToDirective(t *testing.T) {
|
func TestAddToDirective(t *testing.T) {
|
||||||
|
|||||||
@ -37,6 +37,7 @@ func newGatewayRoutesTestRouter() *gin.Engine {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
&config.Config{},
|
&config.Config{},
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@ -62,6 +62,7 @@ func RegisterPaymentRoutes(
|
|||||||
webhook.POST("/alipay", webhookHandler.AlipayNotify)
|
webhook.POST("/alipay", webhookHandler.AlipayNotify)
|
||||||
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
|
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
|
||||||
webhook.POST("/stripe", webhookHandler.StripeWebhook)
|
webhook.POST("/stripe", webhookHandler.StripeWebhook)
|
||||||
|
webhook.POST("/airwallex", webhookHandler.AirwallexWebhook)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Admin payment endpoints (admin auth) ---
|
// --- Admin payment endpoints (admin auth) ---
|
||||||
|
|||||||
@ -774,6 +774,8 @@ func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, accoun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
persistOpenAI429PlanType(ctx, s.accountRepo, account, body)
|
||||||
|
|
||||||
var resetAt *time.Time
|
var resetAt *time.Time
|
||||||
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
|
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
|
||||||
resetAt = calculated
|
resetAt = calculated
|
||||||
|
|||||||
@ -61,12 +61,14 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
|||||||
|
|
||||||
type openAIAccountTestRepo struct {
|
type openAIAccountTestRepo struct {
|
||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
updatedExtra map[string]any
|
updatedExtra map[string]any
|
||||||
rateLimitedID int64
|
bulkUpdatedIDs []int64
|
||||||
rateLimitedAt *time.Time
|
bulkUpdatedPayload AccountBulkUpdate
|
||||||
clearedErrorID int64
|
rateLimitedID int64
|
||||||
setErrorID int64
|
rateLimitedAt *time.Time
|
||||||
setErrorMsg string
|
clearedErrorID int64
|
||||||
|
setErrorID int64
|
||||||
|
setErrorMsg string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||||
@ -74,6 +76,12 @@ func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *openAIAccountTestRepo) BulkUpdate(_ context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||||
|
r.bulkUpdatedIDs = append([]int64(nil), ids...)
|
||||||
|
r.bulkUpdatedPayload = updates
|
||||||
|
return int64(len(ids)), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||||
r.rateLimitedID = id
|
r.rateLimitedID = id
|
||||||
r.rateLimitedAt = &resetAt
|
r.rateLimitedAt = &resetAt
|
||||||
@ -216,6 +224,33 @@ func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleErro
|
|||||||
require.Empty(t, repo.updatedExtra)
|
require.Empty(t, repo.updatedExtra)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_OpenAI429SyncsObservedPlanType(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
|
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","plan_type":"free","resets_at":1777283883}}`)
|
||||||
|
|
||||||
|
repo := &openAIAccountTestRepo{}
|
||||||
|
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||||
|
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 81,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "test-token", "plan_type": "plus"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, []int64{account.ID}, repo.bulkUpdatedIDs)
|
||||||
|
require.Equal(t, "free", repo.bulkUpdatedPayload.Credentials["plan_type"])
|
||||||
|
require.Equal(t, "free", account.Credentials["plan_type"])
|
||||||
|
require.Equal(t, account.ID, repo.rateLimitedID)
|
||||||
|
require.NotNil(t, account.RateLimitResetAt)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
|
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
ctx, _ := newTestContext()
|
ctx, _ := newTestContext()
|
||||||
|
|||||||
@ -2,8 +2,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -110,6 +110,10 @@ type CostBreakdown struct {
|
|||||||
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
|
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrModelPricingUnavailable indicates that none of the configured pricing
|
||||||
|
// sources can price the requested model.
|
||||||
|
var ErrModelPricingUnavailable = errors.New("pricing not found")
|
||||||
|
|
||||||
// BillingService 计费服务
|
// BillingService 计费服务
|
||||||
type BillingService struct {
|
type BillingService struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@ -355,7 +359,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
|||||||
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
return nil, fmt.Errorf("%w for model: %s", ErrModelPricingUnavailable, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
||||||
@ -452,7 +456,7 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
|
|||||||
|
|
||||||
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
||||||
if pricing == nil {
|
if pricing == nil {
|
||||||
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
|
return nil, fmt.Errorf("no pricing available for model: %s: %w", input.Model, ErrModelPricingUnavailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
||||||
|
|||||||
@ -371,6 +371,10 @@ const (
|
|||||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||||
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||||
SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
|
SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
|
||||||
|
// SettingKeyRewriteMessageCacheControl 是否改写 messages[*].content[*].cache_control(默认 false)
|
||||||
|
SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control"
|
||||||
|
// SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值)
|
||||||
|
SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version"
|
||||||
|
|
||||||
// Balance Low Notification
|
// Balance Low Notification
|
||||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||||
|
|||||||
@ -703,7 +703,7 @@ func TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadat
|
|||||||
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||||
"",
|
"",
|
||||||
sessionID,
|
sessionID,
|
||||||
claude.DefaultCLIProductVersion,
|
claude.DefaultCLIVersion,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@ -150,6 +150,21 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
|
|||||||
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnforceCacheControlLimit_CountsToolsAndPreservesMessageAnchorsFirst(t *testing.T) {
|
||||||
|
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
|
||||||
|
|
||||||
|
result := enforceCacheControlLimit(body)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
|
||||||
|
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
|
||||||
|
require.True(t, gjson.GetBytes(result, "system.0.cache_control").Exists())
|
||||||
|
require.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control").Exists())
|
||||||
|
require.True(t, gjson.GetBytes(result, "messages.0.content.1.cache_control").Exists())
|
||||||
|
require.True(t, gjson.GetBytes(result, "messages.0.content.2.cache_control").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(result, "tools.0.cache_control").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
|
func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
|
||||||
body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
|
body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
@ -11,7 +12,7 @@ import (
|
|||||||
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
|
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
|
||||||
// 与 Parrot _strip_message_cache_control 语义一致。
|
// 与 Parrot _strip_message_cache_control 语义一致。
|
||||||
//
|
//
|
||||||
// 为什么必须整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在
|
// 旧策略为什么整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在
|
||||||
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
|
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
|
||||||
// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
|
// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
|
||||||
// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。
|
// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。
|
||||||
@ -85,6 +86,25 @@ func addMessageCacheBreakpoints(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rewriteMessageCacheControlIfEnabled 按系统设置决定是否执行旧版 messages 缓存断点改写。
|
||||||
|
func (s *GatewayService) rewriteMessageCacheControlIfEnabled(ctx context.Context, body []byte) []byte {
|
||||||
|
if s == nil || !s.isRewriteMessageCacheControlEnabled(ctx) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
body = stripMessageCacheControl(body)
|
||||||
|
return addMessageCacheBreakpoints(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) isRewriteMessageCacheControlEnabled(ctx context.Context) bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.settingService != nil {
|
||||||
|
return s.settingService.IsRewriteMessageCacheControlEnabled(ctx)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
|
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
|
||||||
// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组
|
// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组
|
||||||
// (对齐 Parrot _inject_cache_on_msg 的行为)。
|
// (对齐 Parrot _inject_cache_on_msg 的行为)。
|
||||||
|
|||||||
@ -1275,13 +1275,11 @@ func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
|
|||||||
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
|
body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
|
||||||
|
|
||||||
// Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
// Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
||||||
// 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束:
|
// 对齐 Parrot transform_request 里剩余的字段级改写。顺序有语义约束:
|
||||||
// 1) strip:先清除客户端的 messages[*].cache_control(多轮稳定性)
|
// 1) messages cache:仅在配置开启时清除客户端断点并注入代理断点
|
||||||
// 2) breakpoints:再注入 2 个断点(最后一条 + 倒数第二个 user turn)
|
// 2) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1]
|
||||||
// 3) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1]
|
|
||||||
// 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
|
// 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
|
||||||
body = stripMessageCacheControl(body)
|
body = s.rewriteMessageCacheControlIfEnabled(ctx, body)
|
||||||
body = addMessageCacheBreakpoints(body)
|
|
||||||
|
|
||||||
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||||
body = applyToolNameRewriteToBody(body, rw)
|
body = applyToolNameRewriteToBody(body, rw)
|
||||||
@ -4166,7 +4164,7 @@ type cacheControlPath struct {
|
|||||||
log string
|
log string
|
||||||
}
|
}
|
||||||
|
|
||||||
func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) {
|
func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, toolPaths []string, systemPaths []string) {
|
||||||
system := gjson.GetBytes(body, "system")
|
system := gjson.GetBytes(body, "system")
|
||||||
if system.IsArray() {
|
if system.IsArray() {
|
||||||
sysIndex := 0
|
sysIndex := 0
|
||||||
@ -4215,17 +4213,29 @@ func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return invalidThinking, messagePaths, systemPaths
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if tools.IsArray() {
|
||||||
|
toolIndex := 0
|
||||||
|
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||||
|
if tool.Get("cache_control").Exists() {
|
||||||
|
toolPaths = append(toolPaths, fmt.Sprintf("tools.%d.cache_control", toolIndex))
|
||||||
|
}
|
||||||
|
toolIndex++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return invalidThinking, messagePaths, toolPaths, systemPaths
|
||||||
}
|
}
|
||||||
|
|
||||||
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
|
||||||
// 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制
|
// 超限时优先移除工具断点,再移除 messages 断点,最后才移除 system 断点。
|
||||||
func enforceCacheControlLimit(body []byte) []byte {
|
func enforceCacheControlLimit(body []byte) []byte {
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body)
|
invalidThinking, messagePaths, toolPaths, systemPaths := collectCacheControlPaths(body)
|
||||||
out := body
|
out := body
|
||||||
modified := false
|
modified := false
|
||||||
|
|
||||||
@ -4243,7 +4253,7 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
logger.LegacyPrintf("service.gateway", "%s", item.log)
|
logger.LegacyPrintf("service.gateway", "%s", item.log)
|
||||||
}
|
}
|
||||||
|
|
||||||
count := len(messagePaths) + len(systemPaths)
|
count := len(messagePaths) + len(toolPaths) + len(systemPaths)
|
||||||
if count <= maxCacheControlBlocks {
|
if count <= maxCacheControlBlocks {
|
||||||
if modified {
|
if modified {
|
||||||
return out
|
return out
|
||||||
@ -4251,8 +4261,22 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// 超限:优先从 messages 中移除,再从 system 中移除
|
// 超限:优先从 tools 中移除,再从 messages 中移除,最后才从 system 中移除。
|
||||||
remaining := count - maxCacheControlBlocks
|
remaining := count - maxCacheControlBlocks
|
||||||
|
for i := len(toolPaths) - 1; i >= 0 && remaining > 0; i-- {
|
||||||
|
path := toolPaths[i]
|
||||||
|
if !gjson.GetBytes(out, path).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
next, ok := deleteJSONPathBytes(out, path)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
modified = true
|
||||||
|
remaining--
|
||||||
|
}
|
||||||
|
|
||||||
for _, path := range messagePaths {
|
for _, path := range messagePaths {
|
||||||
if remaining <= 0 {
|
if remaining <= 0 {
|
||||||
break
|
break
|
||||||
@ -4476,11 +4500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
|
|
||||||
// D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
// D/E/F: 可选 messages cache 策略 + 工具名混淆 + tools[-1] 断点
|
||||||
// 与 forward_as_chat_completions / forward_as_responses 路径对齐,
|
// 与 forward_as_chat_completions / forward_as_responses 路径对齐,
|
||||||
// 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。
|
// 原生 /v1/messages 路径也走同一套可配置字段级改写。
|
||||||
body = stripMessageCacheControl(body)
|
body = s.rewriteMessageCacheControlIfEnabled(ctx, body)
|
||||||
body = addMessageCacheBreakpoints(body)
|
|
||||||
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||||
body = applyToolNameRewriteToBody(body, rw)
|
body = applyToolNameRewriteToBody(body, rw)
|
||||||
c.Set(toolNameRewriteKey, rw)
|
c.Set(toolNameRewriteKey, rw)
|
||||||
@ -8945,8 +8968,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
|
|
||||||
body = stripMessageCacheControl(body)
|
body = s.rewriteMessageCacheControlIfEnabled(ctx, body)
|
||||||
body = addMessageCacheBreakpoints(body)
|
|
||||||
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
if rw := buildToolNameRewriteFromBody(body); rw != nil {
|
||||||
body = applyToolNameRewriteToBody(body, rw)
|
body = applyToolNameRewriteToBody(body, rw)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -168,13 +168,13 @@ func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
|
// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
|
||||||
// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
|
|
||||||
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐,
|
|
||||||
// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL)
|
|
||||||
// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
|
|
||||||
//
|
//
|
||||||
// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致;
|
// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
|
||||||
// 响应侧 bytes.Replace 会连带还原它们。
|
// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
|
||||||
|
// - 改写 $.messages[*].content[*].name(仅当 type == "tool_use")
|
||||||
|
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点
|
||||||
|
//
|
||||||
|
// 响应侧 bytes.Replace 会连带还原假名 → 真名。
|
||||||
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
|
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
|
||||||
if rw == nil || len(rw.Forward) == 0 {
|
if rw == nil || len(rw.Forward) == 0 {
|
||||||
body = applyToolsLastCacheBreakpoint(body)
|
body = applyToolsLastCacheBreakpoint(body)
|
||||||
@ -213,6 +213,37 @@ func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 同步改写历史消息中的 tool_use.name,确保它和 tools[] 中的假名一致。
|
||||||
|
// 否则 Anthropic 会因为 tool_use 引用了未声明的原始工具名而拒绝请求。
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if messages.IsArray() {
|
||||||
|
messages.ForEach(func(msgKey, msg gjson.Result) bool {
|
||||||
|
msgIdx := int(msgKey.Num)
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.IsArray() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
content.ForEach(func(blkKey, blk gjson.Result) bool {
|
||||||
|
blkIdx := int(blkKey.Num)
|
||||||
|
if blk.Get("type").String() != "tool_use" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
name := blk.Get("name").String()
|
||||||
|
if name == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if fake, ok := rw.Forward[name]; ok {
|
||||||
|
path := fmt.Sprintf("messages.%d.content.%d.name", msgIdx, blkIdx)
|
||||||
|
if next, err := sjson.SetBytes(body, path, fake); err == nil {
|
||||||
|
body = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
body = applyToolsLastCacheBreakpoint(body)
|
body = applyToolsLastCacheBreakpoint(body)
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
@ -69,21 +71,68 @@ func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
|
|||||||
require.NotNil(t, rw)
|
require.NotNil(t, rw)
|
||||||
require.Contains(t, rw.Forward, "sessions_list")
|
require.Contains(t, rw.Forward, "sessions_list")
|
||||||
require.Contains(t, rw.Forward, "session_get")
|
require.Contains(t, rw.Forward, "session_get")
|
||||||
// web_search is a server tool, not rewritten
|
// web_search 是 server tool,不参与工具名改写
|
||||||
require.NotContains(t, rw.Forward, "web_search")
|
require.NotContains(t, rw.Forward, "web_search")
|
||||||
|
|
||||||
out := applyToolNameRewriteToBody(body, rw)
|
out := applyToolNameRewriteToBody(body, rw)
|
||||||
|
|
||||||
// tools[0].name and tools[1].name rewritten; tools[2].name untouched
|
// tools[0].name 和 tools[1].name 被改写,tools[2].name 保持不变
|
||||||
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
|
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
|
||||||
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
|
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
|
||||||
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
|
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
|
||||||
|
|
||||||
// tool_choice.name rewritten
|
// tool_choice.name 被同步改写
|
||||||
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
|
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
|
||||||
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
|
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyToolNameRewriteToBody_RenamesToolUseInMessages(t *testing.T) {
|
||||||
|
// sessions_list 通过静态前缀规则改写为 cc_sess_list
|
||||||
|
// web_search 是 server tool(type != ""),不参与工具名改写
|
||||||
|
// messages 中的 tool_use.name 必须同步改写,才能和 tools[] 保持一致
|
||||||
|
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"assistant","content":[{"type":"tool_use","id":"tu_01","name":"sessions_list","input":{}},{"type":"text","text":"thinking"}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu_01","content":"ok"}]}]}`)
|
||||||
|
rw := buildToolNameRewriteFromBody(body)
|
||||||
|
require.NotNil(t, rw)
|
||||||
|
require.Equal(t, "cc_sess_list", rw.Forward["sessions_list"])
|
||||||
|
|
||||||
|
out := applyToolNameRewriteToBody(body, rw)
|
||||||
|
|
||||||
|
// tools[0].name 被改写
|
||||||
|
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
|
||||||
|
// tools[1].name 是 server tool,保持不变
|
||||||
|
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.1.name").String())
|
||||||
|
// messages[1].content[0].name 是 tool_use,必须同步改写以匹配 tools[]
|
||||||
|
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "messages.1.content.0.name").String())
|
||||||
|
// messages[1].content[1] 是 text,保持不变
|
||||||
|
require.Equal(t, "thinking", gjson.GetBytes(out, "messages.1.content.1.text").String())
|
||||||
|
// messages[2].content[0] 是 tool_result,不包含 name 字段,保持不变
|
||||||
|
require.Equal(t, "ok", gjson.GetBytes(out, "messages.2.content.0.content").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyToolNameRewriteToBody_RenamesToolUseWithDynamicMapping(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"alpha_search","input_schema":{}},{"name":"beta_lookup","input_schema":{}},{"name":"gamma_fetch","input_schema":{}},{"name":"delta_update","input_schema":{}},{"name":"epsilon_parse","input_schema":{}},{"name":"zeta_render","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"gamma_fetch"},"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"tu_dyn","name":"gamma_fetch","input":{}},{"type":"tool_use","id":"tu_srv","name":"web_search","input":{}},{"type":"text","text":"done"}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu_dyn","content":"ok"}]}]}`)
|
||||||
|
rw := buildToolNameRewriteFromBody(body)
|
||||||
|
require.NotNil(t, rw)
|
||||||
|
require.Len(t, rw.Forward, 6)
|
||||||
|
|
||||||
|
fakeGamma := rw.Forward["gamma_fetch"]
|
||||||
|
require.NotEmpty(t, fakeGamma)
|
||||||
|
require.NotEqual(t, "gamma_fetch", fakeGamma)
|
||||||
|
require.NotContains(t, rw.Forward, "web_search")
|
||||||
|
|
||||||
|
out := applyToolNameRewriteToBody(body, rw)
|
||||||
|
|
||||||
|
// 动态映射会改写 tools[]、tool_choice 和历史 tool_use 中的同一个工具名
|
||||||
|
require.Equal(t, fakeGamma, gjson.GetBytes(out, "tools.2.name").String())
|
||||||
|
require.Equal(t, fakeGamma, gjson.GetBytes(out, "tool_choice.name").String())
|
||||||
|
require.Equal(t, fakeGamma, gjson.GetBytes(out, "messages.0.content.0.name").String())
|
||||||
|
// server tool 不参与动态映射,历史 tool_use 中同名引用也保持不变
|
||||||
|
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.6.name").String())
|
||||||
|
require.Equal(t, "web_search", gjson.GetBytes(out, "messages.0.content.1.name").String())
|
||||||
|
// tool_result 依靠 tool_use_id 关联,不需要 name 字段
|
||||||
|
require.Equal(t, "ok", gjson.GetBytes(out, "messages.1.content.0.content").String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
|
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
|
||||||
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
|
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
|
||||||
out := applyToolsLastCacheBreakpoint(body)
|
out := applyToolsLastCacheBreakpoint(body)
|
||||||
@ -141,6 +190,40 @@ func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
|
|||||||
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRewriteMessageCacheControlIfEnabled_DefaultKeepsClientAnchors(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"stable","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"ok"}]},
|
||||||
|
{"role":"user","content":[{"type":"text","text":"latest","cache_control":{"type":"ephemeral","ttl":"5m"}}]}
|
||||||
|
]}`)
|
||||||
|
|
||||||
|
out := (&GatewayService{}).rewriteMessageCacheControlIfEnabled(context.Background(), body)
|
||||||
|
|
||||||
|
require.JSONEq(t, string(body), string(out))
|
||||||
|
require.Equal(t, "1h", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||||
|
require.Equal(t, "5m", gjson.GetBytes(out, "messages.2.content.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMessageCacheControlIfEnabled_OptInPreservesLegacyRewrite(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"stable","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"ok"}]},
|
||||||
|
{"role":"user","content":[{"type":"text","text":"latest","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"done"}]}
|
||||||
|
]}`)
|
||||||
|
repo := &gatewayTTLSettingRepo{data: map[string]string{
|
||||||
|
SettingKeyRewriteMessageCacheControl: "true",
|
||||||
|
}}
|
||||||
|
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
|
||||||
|
svc := &GatewayService{settingService: NewSettingService(repo, &config.Config{})}
|
||||||
|
|
||||||
|
out := svc.rewriteMessageCacheControlIfEnabled(context.Background(), body)
|
||||||
|
|
||||||
|
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
|
||||||
|
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
|
||||||
|
require.Equal(t, "5m", gjson.GetBytes(out, "messages.3.content.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
|
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
|
||||||
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
|
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
|
||||||
body := []byte(`{"tools":[
|
body := []byte(`{"tools":[
|
||||||
|
|||||||
@ -1514,7 +1514,7 @@ func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_SingleCandidate(t *tes
|
|||||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
|
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, selection)
|
require.NotNil(t, selection)
|
||||||
require.Equal(t, int64(71001), selection.Account.ID)
|
require.Equal(t, int64(71001), selection.Account.ID)
|
||||||
@ -1541,7 +1541,7 @@ func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_PicksBetterCandidate(t
|
|||||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
|
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, selection)
|
require.NotNil(t, selection)
|
||||||
require.NotNil(t, selection.Account)
|
require.NotNil(t, selection.Account)
|
||||||
|
|||||||
@ -480,6 +480,61 @@ func TestForwardAsAnthropic_AttachesPreviousResponseIDForCompatContinuation(t *t
|
|||||||
require.Equal(t, "second", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String())
|
require.Equal(t, "second", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_PreviousResponseIDKeepsMultiToolCallContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
httpUpstream: upstream,
|
||||||
|
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-apikey",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"}],"stream":false}`)
|
||||||
|
upstream.resp = openAICompatSSECompletedResponse("resp_first_tools", "gpt-5.3-codex")
|
||||||
|
firstRec := httptest.NewRecorder()
|
||||||
|
firstCtx, _ := gin.CreateTestContext(firstRec)
|
||||||
|
firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody))
|
||||||
|
firstCtx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.3-codex")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, firstResult)
|
||||||
|
|
||||||
|
secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"},{"role":"assistant","content":[{"type":"tool_use","id":"call_one","name":"Read","input":{"file_path":"a.go"}},{"type":"tool_use","id":"call_two","name":"Read","input":{"file_path":"b.go"}}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"call_one","content":"package a"},{"type":"tool_result","tool_use_id":"call_two","content":"package b"},{"type":"text","text":"continue"}]}],"tools":[{"name":"Read","description":"read a file","input_schema":{"type":"object","properties":{"file_path":{"type":"string"}}}}],"stream":false}`)
|
||||||
|
upstream.resp = openAICompatSSECompletedResponse("resp_second_tools", "gpt-5.3-codex")
|
||||||
|
secondRec := httptest.NewRecorder()
|
||||||
|
secondCtx, _ := gin.CreateTestContext(secondRec)
|
||||||
|
secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody))
|
||||||
|
secondCtx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.3-codex")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, secondResult)
|
||||||
|
require.Equal(t, "resp_first_tools", gjson.GetBytes(upstream.lastBody, "previous_response_id").String())
|
||||||
|
|
||||||
|
require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.1.type").String())
|
||||||
|
require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.1.call_id").String())
|
||||||
|
require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.2.type").String())
|
||||||
|
require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.2.call_id").String())
|
||||||
|
require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.3.type").String())
|
||||||
|
require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.3.call_id").String())
|
||||||
|
require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.4.type").String())
|
||||||
|
require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.4.call_id").String())
|
||||||
|
require.Equal(t, "continue", gjson.GetBytes(upstream.lastBody, "input.5.content.0.text").String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestForwardAsAnthropic_ReplaysWithoutContinuationWhenPreviousResponseMissing(t *testing.T) {
|
func TestForwardAsAnthropic_ReplaysWithoutContinuationWhenPreviousResponseMissing(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|||||||
@ -242,6 +242,57 @@ func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing
|
|||||||
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_MissingPricingRecordsZeroCostUsageLog(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_missing_pricing",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 1200,
|
||||||
|
OutputTokens: 300,
|
||||||
|
},
|
||||||
|
Model: "deepseek-v4-flash",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1002, Quota: 100, Group: &Group{RateMultiplier: 1}},
|
||||||
|
User: &User{ID: 2002},
|
||||||
|
Account: &Account{ID: 3002, Type: AccountTypeAPIKey},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||||
|
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "resp_missing_pricing", usageRepo.lastLog.RequestID)
|
||||||
|
require.Equal(t, "deepseek-v4-flash", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, "deepseek-v4-flash", usageRepo.lastLog.RequestedModel)
|
||||||
|
require.Equal(t, 1200, usageRepo.lastLog.InputTokens)
|
||||||
|
require.Equal(t, 300, usageRepo.lastLog.OutputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, string(BillingModeToken), *usageRepo.lastLog.BillingMode)
|
||||||
|
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.BalanceCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||||
groupID := int64(11)
|
groupID := int64(11)
|
||||||
groupRate := 1.4
|
groupRate := 1.4
|
||||||
@ -1157,7 +1208,7 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpr
|
|||||||
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_UnpricedTokenModelFallsBackToZeroCostUsageLog(t *testing.T) {
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
@ -1175,9 +1226,14 @@ func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePrice
|
|||||||
Account: &Account{ID: 30},
|
Account: &Account{ID: 30},
|
||||||
})
|
})
|
||||||
|
|
||||||
require.Error(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
require.Equal(t, 0, usageRepo.calls)
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "not-priceable-alias", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, 20, usageRepo.lastLog.InputTokens)
|
||||||
|
require.Equal(t, 10, usageRepo.lastLog.OutputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||||
require.Equal(t, 0, userRepo.deductCalls)
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
require.Equal(t, 0, subRepo.incrementCalls)
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5273,7 +5273,19 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
}
|
}
|
||||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
|
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if !isUsagePricingUnavailableError(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "service.openai_gateway"),
|
||||||
|
zap.Strings("billing_models", billingModels),
|
||||||
|
zap.String("requested_model", input.OriginalModel),
|
||||||
|
zap.String("mapped_model", input.ChannelMappedModel),
|
||||||
|
zap.String("upstream_model", result.UpstreamModel),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
).Warn("openai_usage.pricing_missing_record_zero_cost", zap.Error(err))
|
||||||
|
cost = &CostBreakdown{BillingMode: string(BillingModeToken)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine billing type
|
// Determine billing type
|
||||||
@ -5439,6 +5451,17 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
|||||||
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
|
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isUsagePricingUnavailableError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrModelPricingUnavailable) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "no pricing available") || strings.Contains(msg, "pricing not found")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
|
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
apiKey *APIKey,
|
apiKey *APIKey,
|
||||||
|
|||||||
@ -37,10 +37,7 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
start := len(items) - 1
|
start := latestAnthropicCompatResponsesInputTurnStart(items)
|
||||||
for start > 0 && items[start].Type == "function_call_output" {
|
|
||||||
start--
|
|
||||||
}
|
|
||||||
trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...)
|
trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...)
|
||||||
if len(trimmed) == len(items) {
|
if len(trimmed) == len(items) {
|
||||||
return
|
return
|
||||||
@ -50,6 +47,63 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func latestAnthropicCompatResponsesInputTurnStart(items []apicompat.ResponsesInputItem) int {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
start := len(items) - 1
|
||||||
|
last := items[start]
|
||||||
|
switch {
|
||||||
|
case last.Type == "function_call_output":
|
||||||
|
for start > 0 && items[start-1].Type == "function_call_output" {
|
||||||
|
start--
|
||||||
|
}
|
||||||
|
case last.Type == "message" && last.Role == "user":
|
||||||
|
for start > 0 && items[start-1].Type == "function_call_output" {
|
||||||
|
start--
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return start
|
||||||
|
}
|
||||||
|
|
||||||
|
return expandAnthropicCompatResponsesInputToolCallStart(items, start)
|
||||||
|
}
|
||||||
|
|
||||||
|
func expandAnthropicCompatResponsesInputToolCallStart(items []apicompat.ResponsesInputItem, start int) int {
|
||||||
|
if start < 0 || start >= len(items) {
|
||||||
|
return start
|
||||||
|
}
|
||||||
|
|
||||||
|
needed := make(map[string]struct{})
|
||||||
|
for i := start; i < len(items); i++ {
|
||||||
|
if items[i].Type != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := strings.TrimSpace(items[i].CallID)
|
||||||
|
if callID != "" {
|
||||||
|
needed[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(needed) == 0 {
|
||||||
|
return start
|
||||||
|
}
|
||||||
|
|
||||||
|
expandedStart := start
|
||||||
|
for i := start - 1; i >= 0 && len(needed) > 0; i-- {
|
||||||
|
if items[i].Type != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := strings.TrimSpace(items[i].CallID)
|
||||||
|
if _, ok := needed[callID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(needed, callID)
|
||||||
|
expandedStart = i
|
||||||
|
}
|
||||||
|
return expandedStart
|
||||||
|
}
|
||||||
|
|
||||||
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||||
if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound {
|
if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -1552,6 +1552,15 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
|
||||||
|
for _, item := range items {
|
||||||
|
if gjson.GetBytes(item, "type").String() == "function_call_output" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func buildOpenAIWSReplayInputSequence(
|
func buildOpenAIWSReplayInputSequence(
|
||||||
previousFullInput []json.RawMessage,
|
previousFullInput []json.RawMessage,
|
||||||
previousFullInputExists bool,
|
previousFullInputExists bool,
|
||||||
@ -3117,6 +3126,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
currentTurnReplayInput := []json.RawMessage(nil)
|
currentTurnReplayInput := []json.RawMessage(nil)
|
||||||
currentTurnReplayInputExists := false
|
currentTurnReplayInputExists := false
|
||||||
skipBeforeTurn := false
|
skipBeforeTurn := false
|
||||||
|
hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool {
|
||||||
|
if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
|
||||||
|
}
|
||||||
resetSessionLease := func(markBroken bool) {
|
resetSessionLease := func(markBroken bool) {
|
||||||
if sessionLease == nil {
|
if sessionLease == nil {
|
||||||
return
|
return
|
||||||
@ -3139,7 +3154,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
||||||
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
||||||
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
||||||
if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() {
|
if hasCurrentOrReplayFunctionCallOutput(currentPayload) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if isStrictAffinityTurn(currentPayload) {
|
if isStrictAffinityTurn(currentPayload) {
|
||||||
@ -3298,6 +3313,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
currentTurnReplayInput = nextReplayInput
|
currentTurnReplayInput = nextReplayInput
|
||||||
currentTurnReplayInputExists = nextReplayInputExists
|
currentTurnReplayInputExists = nextReplayInputExists
|
||||||
}
|
}
|
||||||
|
replayHasFunctionCallOutput := currentTurnReplayInputExists &&
|
||||||
|
openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
|
||||||
|
hasFunctionCallOutput = hasFunctionCallOutput || replayHasFunctionCallOutput
|
||||||
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
|
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
|
||||||
shouldKeepPreviousResponseID := false
|
shouldKeepPreviousResponseID := false
|
||||||
strictReason := ""
|
strictReason := ""
|
||||||
@ -3416,7 +3434,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
||||||
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
||||||
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
||||||
hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
hasFCOutput := hasFunctionCallOutput
|
||||||
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
|
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
|
||||||
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
||||||
if dropErr != nil || !removed {
|
if dropErr != nil || !removed {
|
||||||
@ -3464,6 +3482,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if hasFCOutput && currentPreviousResponseID != "" {
|
||||||
|
logOpenAIWSModeInfo(
|
||||||
|
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=function_call_output action=fail_close previous_response_id=%s",
|
||||||
|
account.ID,
|
||||||
|
turn,
|
||||||
|
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||||||
|
)
|
||||||
|
}
|
||||||
resetSessionLease(true)
|
resetSessionLease(true)
|
||||||
return NewOpenAIWSClientCloseError(
|
return NewOpenAIWSClientCloseError(
|
||||||
coderws.StatusPolicyViolation,
|
coderws.StatusPolicyViolation,
|
||||||
|
|||||||
@ -1918,6 +1918,298 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStr
|
|||||||
require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
|
require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
|
||||||
|
openAIWSIngressPreflightPingIdle = 0
|
||||||
|
defer func() {
|
||||||
|
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||||||
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
firstConn := &openAIWSPreflightFailConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"Previous response not found."}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := &openAIWSQueueDialer{
|
||||||
|
conns: []openAIWSClientConn{firstConn, secondConn},
|
||||||
|
}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(dialer)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 129,
|
||||||
|
Name: "openai-ingress-preflight-replay-function-output",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeMessage := func(payload string) {
|
||||||
|
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||||
|
}
|
||||||
|
readMessage := func() []byte {
|
||||||
|
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
msgType, message, readErr := clientConn.Read(readCtx)
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, coderws.MessageText, msgType)
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
|
||||||
|
firstTurn := readMessage()
|
||||||
|
require.Equal(t, "resp_turn_ping_replay_fc_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||||
|
|
||||||
|
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_fc_1","input":[{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
require.Error(t, serverErr)
|
||||||
|
var closeErr *OpenAIWSClientCloseError
|
||||||
|
require.ErrorAs(t, serverErr, &closeErr)
|
||||||
|
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
|
||||||
|
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 ingress websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, dialer.DialCount(), "需要 previous_response_id 的 function_call_output 在原连接不可用时不应换新连接重试")
|
||||||
|
secondConn.mu.Lock()
|
||||||
|
secondWrites := append([]map[string]any(nil), secondConn.writes...)
|
||||||
|
secondConn.mu.Unlock()
|
||||||
|
require.Empty(t, secondWrites, "不能把旧连接的 previous_response_id 发送到新上游,否则会触发 previous_response_not_found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenReplayHasFunctionCallOutput(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
|
||||||
|
openAIWSIngressPreflightPingIdle = 0
|
||||||
|
defer func() {
|
||||||
|
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||||||
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
firstConn := &openAIWSPreflightFailConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_only_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found for function call output with call_id call_replay_1.","param":"input"}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := &openAIWSQueueDialer{
|
||||||
|
conns: []openAIWSClientConn{firstConn, secondConn},
|
||||||
|
}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(dialer)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 130,
|
||||||
|
Name: "openai-ingress-preflight-replay-only-function-output",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeMessage := func(payload string) {
|
||||||
|
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||||
|
}
|
||||||
|
readMessage := func() []byte {
|
||||||
|
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
msgType, message, readErr := clientConn.Read(readCtx)
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, coderws.MessageText, msgType)
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
|
||||||
|
firstTurn := readMessage()
|
||||||
|
require.Equal(t, "resp_turn_ping_replay_only_fc_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||||
|
|
||||||
|
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_only_fc_1","input":[{"type":"input_text","text":"after tool output"}]}`)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
require.Error(t, serverErr)
|
||||||
|
var closeErr *OpenAIWSClientCloseError
|
||||||
|
require.ErrorAs(t, serverErr, &closeErr)
|
||||||
|
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
|
||||||
|
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 ingress websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, dialer.DialCount(), "replay input 带 function_call_output 时不应换新连接重试")
|
||||||
|
secondConn.mu.Lock()
|
||||||
|
secondWrites := append([]map[string]any(nil), secondConn.writes...)
|
||||||
|
secondConn.mu.Unlock()
|
||||||
|
require.Empty(t, secondWrites, "不能把会触发 No tool call found 的重放请求发到新上游")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,16 +23,17 @@ func calculateCreditedBalance(paymentAmount, multiplier float64) float64 {
|
|||||||
InexactFloat64()
|
InexactFloat64()
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64) float64 {
|
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64, currency string) float64 {
|
||||||
if orderAmount <= 0 || payAmount <= 0 || refundAmount <= 0 {
|
if orderAmount <= 0 || payAmount <= 0 || refundAmount <= 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
if math.Abs(refundAmount-orderAmount) <= amountToleranceCNY {
|
fractionDigits := int32(payment.CurrencyMaxFractionDigits(currency))
|
||||||
return decimal.NewFromFloat(payAmount).Round(2).InexactFloat64()
|
if math.Abs(refundAmount-orderAmount) <= paymentAmountToleranceForCurrency(currency) {
|
||||||
|
return decimal.NewFromFloat(payAmount).Round(fractionDigits).InexactFloat64()
|
||||||
}
|
}
|
||||||
return decimal.NewFromFloat(payAmount).
|
return decimal.NewFromFloat(payAmount).
|
||||||
Mul(decimal.NewFromFloat(refundAmount)).
|
Mul(decimal.NewFromFloat(refundAmount)).
|
||||||
Div(decimal.NewFromFloat(orderAmount)).
|
Div(decimal.NewFromFloat(orderAmount)).
|
||||||
Round(2).
|
Round(fractionDigits).
|
||||||
InexactFloat64()
|
InexactFloat64()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAvailableMethodLimits collects all payment types from enabled provider
|
// GetAvailableMethodLimits collects all payment types from enabled provider
|
||||||
@ -25,7 +26,12 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
|
|||||||
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
||||||
}
|
}
|
||||||
for pt, insts := range typeInstances {
|
for pt, insts := range typeInstances {
|
||||||
|
currency, ok := s.pcAggregateMethodCurrency(insts)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
ml := pcAggregateMethodLimits(pt, insts)
|
ml := pcAggregateMethodLimits(pt, insts)
|
||||||
|
ml.Currency = currency
|
||||||
resp.Methods[ml.PaymentType] = ml
|
resp.Methods[ml.PaymentType] = ml
|
||||||
}
|
}
|
||||||
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
|
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
|
||||||
@ -82,11 +88,81 @@ func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []stri
|
|||||||
matching = append(matching, inst)
|
matching = append(matching, inst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result = append(result, pcAggregateMethodLimits(pt, matching))
|
currency, ok := s.pcAggregateMethodCurrency(matching)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ml := pcAggregateMethodLimits(pt, matching)
|
||||||
|
ml.Currency = currency
|
||||||
|
result = append(result, ml)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PaymentConfigService) ValidateMethodCurrencyConsistency(ctx context.Context, paymentType string) (string, error) {
|
||||||
|
method := NormalizeVisibleMethod(paymentType)
|
||||||
|
if method == "" || s == nil || s.entClient == nil {
|
||||||
|
return payment.DefaultPaymentCurrency, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||||
|
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("query provider instances: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
typeInstances := pcGroupByPaymentType(instances)
|
||||||
|
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
|
||||||
|
matching := typeInstances[method]
|
||||||
|
if len(matching) == 0 {
|
||||||
|
return payment.DefaultPaymentCurrency, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
currency, ok := s.pcAggregateMethodCurrency(matching)
|
||||||
|
if !ok {
|
||||||
|
return "", infraerrors.ServiceUnavailable(
|
||||||
|
"PAYMENT_METHOD_CURRENCY_CONFLICT",
|
||||||
|
"payment method has enabled provider instances with mixed currencies",
|
||||||
|
).WithMetadata(map[string]string{"payment_type": method})
|
||||||
|
}
|
||||||
|
return currency, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PaymentConfigService) pcAggregateMethodCurrency(instances []*dbent.PaymentProviderInstance) (string, bool) {
|
||||||
|
currency := ""
|
||||||
|
for _, inst := range instances {
|
||||||
|
next := s.pcInstancePaymentCurrency(inst)
|
||||||
|
if next == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if currency == "" {
|
||||||
|
currency = next
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if currency != next {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currency == "" {
|
||||||
|
return payment.DefaultPaymentCurrency, true
|
||||||
|
}
|
||||||
|
return currency, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PaymentConfigService) pcInstancePaymentCurrency(inst *dbent.PaymentProviderInstance) string {
|
||||||
|
if inst == nil {
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
cfg := map[string]string{}
|
||||||
|
if s != nil {
|
||||||
|
decrypted, err := s.decryptConfig(inst.Config)
|
||||||
|
if err == nil && decrypted != nil {
|
||||||
|
cfg = decrypted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return paymentProviderConfigCurrency(inst.ProviderKey, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
// pcGroupByPaymentType groups instances by user-facing payment type.
|
// pcGroupByPaymentType groups instances by user-facing payment type.
|
||||||
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
|
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
|
||||||
// because the user sees a single "Stripe" button, not individual sub-methods.
|
// because the user sees a single "Stripe" button, not individual sub-methods.
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -199,6 +200,61 @@ func TestPcGroupByPaymentType(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPcAggregateMethodCurrency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
svc := &PaymentConfigService{}
|
||||||
|
stripe := makeInstance(1, payment.TypeStripe, payment.TypeStripe, "")
|
||||||
|
stripe.Config = `{"currency":"hkd"}`
|
||||||
|
currency, ok := svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{stripe})
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "HKD", currency)
|
||||||
|
|
||||||
|
airwallex := makeInstance(2, payment.TypeAirwallex, payment.TypeAirwallex, "")
|
||||||
|
airwallex.Config = `{"currency":"usd"}`
|
||||||
|
currency, ok = svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{stripe, airwallex})
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Empty(t, currency)
|
||||||
|
|
||||||
|
easypay := makeInstance(3, payment.TypeEasyPay, payment.TypeAlipay, "")
|
||||||
|
currency, ok = svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{easypay})
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, payment.DefaultPaymentCurrency, currency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableMethodLimitsOmitsMixedCurrencyMethod(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
|
|
||||||
|
_, err := client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeStripe).
|
||||||
|
SetName("Stripe HKD").
|
||||||
|
SetConfig(`{"currency":"HKD"}`).
|
||||||
|
SetSupportedTypes("card,link").
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeStripe).
|
||||||
|
SetName("Stripe USD").
|
||||||
|
SetConfig(`{"currency":"USD"}`).
|
||||||
|
SetSupportedTypes("card,link").
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &PaymentConfigService{entClient: client}
|
||||||
|
resp, err := svc.GetAvailableMethodLimits(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotContains(t, resp.Methods, payment.TypeStripe)
|
||||||
|
|
||||||
|
_, err = svc.ValidateMethodCurrencyConsistency(ctx, payment.TypeStripe)
|
||||||
|
require.Error(t, err)
|
||||||
|
appErr := infraerrors.FromError(err)
|
||||||
|
require.Equal(t, "PAYMENT_METHOD_CURRENCY_CONFLICT", appErr.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPcComputeGlobalRange(t *testing.T) {
|
func TestPcComputeGlobalRange(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@ -110,10 +110,11 @@ var pendingOrderStatuses = []string{
|
|||||||
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
|
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
|
||||||
// stripe publishableKey) are returned in plaintext by the admin GET API.
|
// stripe publishableKey) are returned in plaintext by the admin GET API.
|
||||||
var providerSensitiveConfigFields = map[string]map[string]struct{}{
|
var providerSensitiveConfigFields = map[string]map[string]struct{}{
|
||||||
payment.TypeEasyPay: {"pkey": {}},
|
payment.TypeEasyPay: {"pkey": {}},
|
||||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
|
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
|
||||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
|
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
|
||||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
||||||
|
payment.TypeAirwallex: {"apikey": {}, "webhooksecret": {}},
|
||||||
}
|
}
|
||||||
|
|
||||||
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
|
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
|
||||||
@ -121,10 +122,11 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
|
|||||||
// all provider identity fields that are snapshotted into orders or used by
|
// all provider identity fields that are snapshotted into orders or used by
|
||||||
// webhook/refund verification.
|
// webhook/refund verification.
|
||||||
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
|
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
|
||||||
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
|
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
|
||||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
|
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
|
||||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
|
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
|
||||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}, "currency": {}},
|
||||||
|
payment.TypeAirwallex: {"clientid": {}, "apikey": {}, "webhooksecret": {}, "apibase": {}, "accountid": {}, "currency": {}},
|
||||||
}
|
}
|
||||||
|
|
||||||
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
|
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
|
||||||
@ -175,7 +177,7 @@ func (s *PaymentConfigService) countPendingOrdersByPlan(ctx context.Context, pla
|
|||||||
}
|
}
|
||||||
|
|
||||||
var validProviderKeys = map[string]bool{
|
var validProviderKeys = map[string]bool{
|
||||||
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true,
|
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true, payment.TypeAirwallex: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
|
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
|
||||||
|
|||||||
@ -44,6 +44,13 @@ func TestValidateProviderRequest(t *testing.T) {
|
|||||||
supportedTypes: "",
|
supportedTypes: "",
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "valid airwallex provider",
|
||||||
|
providerKey: payment.TypeAirwallex,
|
||||||
|
providerName: "Airwallex Provider",
|
||||||
|
supportedTypes: payment.TypeAirwallex,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "valid alipay provider",
|
name: "valid alipay provider",
|
||||||
providerKey: "alipay",
|
providerKey: "alipay",
|
||||||
@ -120,6 +127,7 @@ func TestIsSensitiveProviderConfigField(t *testing.T) {
|
|||||||
{"stripe", "webhookSecret", true},
|
{"stripe", "webhookSecret", true},
|
||||||
{"stripe", "SecretKey", true}, // case-insensitive
|
{"stripe", "SecretKey", true}, // case-insensitive
|
||||||
{"stripe", "publishableKey", false},
|
{"stripe", "publishableKey", false},
|
||||||
|
{"stripe", "currency", false},
|
||||||
{"stripe", "appId", false},
|
{"stripe", "appId", false},
|
||||||
|
|
||||||
// Alipay
|
// Alipay
|
||||||
@ -142,6 +150,14 @@ func TestIsSensitiveProviderConfigField(t *testing.T) {
|
|||||||
{"easypay", "pid", false},
|
{"easypay", "pid", false},
|
||||||
{"easypay", "apiBase", false},
|
{"easypay", "apiBase", false},
|
||||||
|
|
||||||
|
// Airwallex
|
||||||
|
{payment.TypeAirwallex, "apiKey", true},
|
||||||
|
{payment.TypeAirwallex, "webhookSecret", true},
|
||||||
|
{payment.TypeAirwallex, "clientId", false},
|
||||||
|
{payment.TypeAirwallex, "apiBase", false},
|
||||||
|
{payment.TypeAirwallex, "accountId", false},
|
||||||
|
{payment.TypeAirwallex, "currency", false},
|
||||||
|
|
||||||
// Unknown provider: never sensitive
|
// Unknown provider: never sensitive
|
||||||
{"unknown", "secretKey", false},
|
{"unknown", "secretKey", false},
|
||||||
}
|
}
|
||||||
@ -395,6 +411,42 @@ func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t
|
|||||||
fieldName: "pid",
|
fieldName: "pid",
|
||||||
wantValue: "pid-test",
|
wantValue: "pid-test",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "stripe currency",
|
||||||
|
providerKey: payment.TypeStripe,
|
||||||
|
createConfig: validStripeProviderConfig,
|
||||||
|
supportedType: []string{payment.TypeStripe},
|
||||||
|
updateConfig: map[string]string{"currency": "HKD"},
|
||||||
|
fieldName: "currency",
|
||||||
|
wantValue: "CNY",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "airwallex accountId",
|
||||||
|
providerKey: payment.TypeAirwallex,
|
||||||
|
createConfig: validAirwallexProviderConfig,
|
||||||
|
supportedType: []string{payment.TypeAirwallex},
|
||||||
|
updateConfig: map[string]string{"accountId": "acct-updated"},
|
||||||
|
fieldName: "accountId",
|
||||||
|
wantValue: "acct-test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "airwallex currency",
|
||||||
|
providerKey: payment.TypeAirwallex,
|
||||||
|
createConfig: validAirwallexProviderConfig,
|
||||||
|
supportedType: []string{payment.TypeAirwallex},
|
||||||
|
updateConfig: map[string]string{"currency": "HKD"},
|
||||||
|
fieldName: "currency",
|
||||||
|
wantValue: "CNY",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "airwallex webhookSecret",
|
||||||
|
providerKey: payment.TypeAirwallex,
|
||||||
|
createConfig: validAirwallexProviderConfig,
|
||||||
|
supportedType: []string{payment.TypeAirwallex},
|
||||||
|
updateConfig: map[string]string{"webhookSecret": "whsec-updated"},
|
||||||
|
fieldName: "webhookSecret",
|
||||||
|
wantValue: "whsec-test",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
@ -506,6 +558,39 @@ func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *test
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateProviderInstanceClearsAirwallexAccountID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
|
svc := &PaymentConfigService{
|
||||||
|
entClient: client,
|
||||||
|
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
|
||||||
|
}
|
||||||
|
|
||||||
|
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
||||||
|
ProviderKey: payment.TypeAirwallex,
|
||||||
|
Name: "airwallex-clear-account",
|
||||||
|
Config: validAirwallexProviderConfig(t),
|
||||||
|
SupportedTypes: []string{payment.TypeAirwallex},
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
|
||||||
|
Config: map[string]string{"accountId": ""},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
|
||||||
|
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
cfg, err := svc.decryptConfig(saved.Config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, cfg["accountId"])
|
||||||
|
require.Equal(t, "client-id-test", cfg["clientId"])
|
||||||
|
}
|
||||||
|
|
||||||
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
|
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@ -545,11 +630,26 @@ func providerPendingOrderPaymentType(providerKey string) string {
|
|||||||
return payment.TypeWxpay
|
return payment.TypeWxpay
|
||||||
case payment.TypeAlipay:
|
case payment.TypeAlipay:
|
||||||
return payment.TypeAlipay
|
return payment.TypeAlipay
|
||||||
|
case payment.TypeAirwallex:
|
||||||
|
return payment.TypeAirwallex
|
||||||
|
case payment.TypeStripe:
|
||||||
|
return payment.TypeStripe
|
||||||
default:
|
default:
|
||||||
return payment.TypeAlipay
|
return payment.TypeAlipay
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validStripeProviderConfig(t *testing.T) map[string]string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
return map[string]string{
|
||||||
|
"secretKey": "sk_test_123",
|
||||||
|
"publishableKey": "pk_test_123",
|
||||||
|
"webhookSecret": "whsec-test",
|
||||||
|
"currency": "CNY",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func boolPtrValue(v bool) *bool {
|
func boolPtrValue(v bool) *bool {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
@ -577,6 +677,19 @@ func validEasyPayProviderConfig(t *testing.T) map[string]string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validAirwallexProviderConfig(t *testing.T) map[string]string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
return map[string]string{
|
||||||
|
"clientId": "client-id-test",
|
||||||
|
"apiKey": "api-key-test",
|
||||||
|
"webhookSecret": "whsec-test",
|
||||||
|
"apiBase": "https://api-demo.airwallex.com/api/v1",
|
||||||
|
"accountId": "acct-test",
|
||||||
|
"currency": "CNY",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func validWxpayProviderConfig(t *testing.T) map[string]string {
|
func validWxpayProviderConfig(t *testing.T) map[string]string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@ -103,6 +103,7 @@ type UpdatePaymentConfigRequest struct {
|
|||||||
// MethodLimits holds per-payment-type limits.
|
// MethodLimits holds per-payment-type limits.
|
||||||
type MethodLimits struct {
|
type MethodLimits struct {
|
||||||
PaymentType string `json:"payment_type"`
|
PaymentType string `json:"payment_type"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
FeeRate float64 `json:"fee_rate"`
|
FeeRate float64 `json:"fee_rate"`
|
||||||
DailyLimit float64 `json:"daily_limit"`
|
DailyLimit float64 `json:"daily_limit"`
|
||||||
SingleMin float64 `json:"single_min"`
|
SingleMin float64 `json:"single_min"`
|
||||||
|
|||||||
28
backend/internal/service/payment_currency.go
Normal file
28
backend/internal/service/payment_currency.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
)
|
||||||
|
|
||||||
|
func paymentProviderConfigCurrency(providerKey string, cfg map[string]string) string {
|
||||||
|
switch strings.TrimSpace(providerKey) {
|
||||||
|
case payment.TypeStripe, payment.TypeAirwallex:
|
||||||
|
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
|
||||||
|
if err == nil {
|
||||||
|
return currency
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
|
|
||||||
|
func PaymentOrderCurrency(order *dbent.PaymentOrder) string {
|
||||||
|
if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
|
||||||
|
if currency, err := payment.NormalizePaymentCurrency(snapshot.Currency); err == nil {
|
||||||
|
return currency
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return payment.DefaultPaymentCurrency
|
||||||
|
}
|
||||||
@ -101,13 +101,21 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
|
|||||||
})
|
})
|
||||||
return fmt.Errorf("invalid paid amount from provider: %v", paid)
|
return fmt.Errorf("invalid paid amount from provider: %v", paid)
|
||||||
}
|
}
|
||||||
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
|
if math.Abs(paid-o.PayAmount) > paymentAmountToleranceForCurrency(PaymentOrderCurrency(o)) {
|
||||||
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
|
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
|
||||||
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
|
return fmt.Errorf("amount mismatch: expected %s, got %s", strconv.FormatFloat(o.PayAmount, 'f', -1, 64), strconv.FormatFloat(paid, 'f', -1, 64))
|
||||||
}
|
}
|
||||||
return s.toPaid(ctx, o, tradeNo, paid, pk)
|
return s.toPaid(ctx, o, tradeNo, paid, pk)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func paymentAmountToleranceForCurrency(currency string) float64 {
|
||||||
|
minorUnit := payment.CurrencyMinorUnit(currency)
|
||||||
|
if minorUnit <= 2 {
|
||||||
|
return amountToleranceCNY
|
||||||
|
}
|
||||||
|
return math.Pow10(-minorUnit) / 2
|
||||||
|
}
|
||||||
|
|
||||||
func isValidProviderAmount(amount float64) bool {
|
func isValidProviderAmount(amount float64) bool {
|
||||||
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
|
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -366,3 +366,55 @@ func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *t
|
|||||||
})
|
})
|
||||||
assert.ErrorContains(t, err, "easypay pid mismatch")
|
assert.ErrorContains(t, err, "easypay pid mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateProviderNotificationMetadataRejectsAirwallexSnapshotMismatch(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
order := &dbent.PaymentOrder{
|
||||||
|
PaymentType: payment.TypeAirwallex,
|
||||||
|
ProviderSnapshot: map[string]any{
|
||||||
|
"schema_version": 2,
|
||||||
|
"merchant_id": "acct_expected",
|
||||||
|
"currency": "CNY",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := validateProviderNotificationMetadata(order, payment.TypeAirwallex, map[string]string{
|
||||||
|
"account_id": "acct_other",
|
||||||
|
"currency": "CNY",
|
||||||
|
"status": "SUCCEEDED",
|
||||||
|
})
|
||||||
|
assert.ErrorContains(t, err, "airwallex account_id mismatch")
|
||||||
|
|
||||||
|
err = validateProviderNotificationMetadata(order, payment.TypeAirwallex, map[string]string{
|
||||||
|
"account_id": "acct_expected",
|
||||||
|
"currency": "USD",
|
||||||
|
"status": "SUCCEEDED",
|
||||||
|
})
|
||||||
|
assert.ErrorContains(t, err, "airwallex currency mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateProviderNotificationMetadataRejectsStripeCurrencyMismatch(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
order := &dbent.PaymentOrder{
|
||||||
|
PaymentType: payment.TypeStripe,
|
||||||
|
ProviderSnapshot: map[string]any{
|
||||||
|
"schema_version": 2,
|
||||||
|
"currency": "HKD",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := validateProviderNotificationMetadata(order, payment.TypeStripe, map[string]string{
|
||||||
|
"currency": "USD",
|
||||||
|
})
|
||||||
|
assert.ErrorContains(t, err, "stripe currency mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPaymentAmountToleranceForThreeDecimalCurrency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("CNY"))
|
||||||
|
assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("JPY"))
|
||||||
|
assert.InDelta(t, 0.0005, paymentAmountToleranceForCurrency("KWD"), 1e-12)
|
||||||
|
}
|
||||||
|
|||||||
@ -57,8 +57,17 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
|
|||||||
orderAmount = calculateCreditedBalance(req.Amount, cfg.BalanceRechargeMultiplier)
|
orderAmount = calculateCreditedBalance(req.Amount, cfg.BalanceRechargeMultiplier)
|
||||||
}
|
}
|
||||||
feeRate := cfg.RechargeFeeRate
|
feeRate := cfg.RechargeFeeRate
|
||||||
payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
|
methodCurrency := payment.DefaultPaymentCurrency
|
||||||
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
|
if s.configService != nil {
|
||||||
|
methodCurrency, err = s.configService.ValidateMethodCurrencyConsistency(ctx, req.PaymentType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payAmountStr, payAmount, err := calculateCreateOrderPayAmount(limitAmount, feeRate, methodCurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
|
sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -66,6 +75,19 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
|
|||||||
if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
|
if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
selectedCurrency := payment.DefaultPaymentCurrency
|
||||||
|
if sel != nil {
|
||||||
|
selectedCurrency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
|
||||||
|
}
|
||||||
|
if selectedCurrency != methodCurrency {
|
||||||
|
payAmountStr, payAmount, err = calculateCreateOrderPayAmount(limitAmount, feeRate, selectedCurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := validateSelectedCreateOrderAmountCurrency(payAmountStr, sel); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
|
oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -257,7 +279,7 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
|
|||||||
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
|
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
|
||||||
snapshot["merchant_id"] = merchantID
|
snapshot["merchant_id"] = merchantID
|
||||||
}
|
}
|
||||||
snapshot["currency"] = "CNY"
|
snapshot["currency"] = payment.DefaultPaymentCurrency
|
||||||
}
|
}
|
||||||
if providerKey == payment.TypeAlipay {
|
if providerKey == payment.TypeAlipay {
|
||||||
if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
|
if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
|
||||||
@ -269,6 +291,15 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
|
|||||||
snapshot["merchant_id"] = merchantID
|
snapshot["merchant_id"] = merchantID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if providerKey == payment.TypeStripe {
|
||||||
|
snapshot["currency"] = paymentProviderConfigCurrency(providerKey, sel.Config)
|
||||||
|
}
|
||||||
|
if providerKey == payment.TypeAirwallex {
|
||||||
|
if accountID := strings.TrimSpace(sel.Config["accountId"]); accountID != "" {
|
||||||
|
snapshot["merchant_id"] = accountID
|
||||||
|
}
|
||||||
|
snapshot["currency"] = paymentProviderConfigCurrency(providerKey, sel.Config)
|
||||||
|
}
|
||||||
|
|
||||||
if len(snapshot) == 1 {
|
if len(snapshot) == 1 {
|
||||||
return nil
|
return nil
|
||||||
@ -377,7 +408,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
|
|||||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
|
return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
|
||||||
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
|
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
|
||||||
}
|
}
|
||||||
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
|
subject := s.buildPaymentSubject(plan, limitAmount, cfg, sel)
|
||||||
outTradeNo := order.OutTradeNo
|
outTradeNo := order.OutTradeNo
|
||||||
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
|
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -466,20 +497,24 @@ func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
|
|||||||
return sel.SupportedTypes
|
return sel.SupportedTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
|
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig, sel *payment.InstanceSelection) string {
|
||||||
if plan != nil {
|
if plan != nil {
|
||||||
if plan.ProductName != "" {
|
if plan.ProductName != "" {
|
||||||
return plan.ProductName
|
return plan.ProductName
|
||||||
}
|
}
|
||||||
return "Sub2API Subscription " + plan.Name
|
return "Sub2API Subscription " + plan.Name
|
||||||
}
|
}
|
||||||
amountStr := strconv.FormatFloat(limitAmount, 'f', 2, 64)
|
currency := payment.DefaultPaymentCurrency
|
||||||
|
if sel != nil {
|
||||||
|
currency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
|
||||||
|
}
|
||||||
|
amountStr := payment.FormatAmountForCurrency(limitAmount, currency)
|
||||||
pf := strings.TrimSpace(cfg.ProductNamePrefix)
|
pf := strings.TrimSpace(cfg.ProductNamePrefix)
|
||||||
sf := strings.TrimSpace(cfg.ProductNameSuffix)
|
sf := strings.TrimSpace(cfg.ProductNameSuffix)
|
||||||
if pf != "" || sf != "" {
|
if pf != "" || sf != "" {
|
||||||
return strings.TrimSpace(pf + " " + amountStr + " " + sf)
|
return strings.TrimSpace(pf + " " + amountStr + " " + sf)
|
||||||
}
|
}
|
||||||
return "Sub2API " + amountStr + " CNY"
|
return "Sub2API " + amountStr + " " + currency
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
|
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
|
||||||
@ -540,6 +575,44 @@ func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func calculateCreateOrderPayAmount(limitAmount, feeRate float64, currency string) (string, float64, error) {
|
||||||
|
if err := validateCreateOrderAmountCurrency(limitAmount, currency); err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
payAmountStr := payment.CalculatePayAmountForCurrency(limitAmount, feeRate, currency)
|
||||||
|
if _, err := payment.AmountToMinorUnit(payAmountStr, currency); err != nil {
|
||||||
|
return "", 0, infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
|
||||||
|
WithMetadata(map[string]string{"currency": currency})
|
||||||
|
}
|
||||||
|
payAmount, err := strconv.ParseFloat(payAmountStr, 64)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, infraerrors.BadRequest("INVALID_AMOUNT", "invalid payment amount").
|
||||||
|
WithMetadata(map[string]string{"currency": currency})
|
||||||
|
}
|
||||||
|
return payAmountStr, payAmount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCreateOrderAmountCurrency(amount float64, currency string) error {
|
||||||
|
amountStr := strconv.FormatFloat(amount, 'f', -1, 64)
|
||||||
|
if _, err := payment.AmountToMinorUnit(amountStr, currency); err != nil {
|
||||||
|
return infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
|
||||||
|
WithMetadata(map[string]string{"currency": currency})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSelectedCreateOrderAmountCurrency(payAmount string, sel *payment.InstanceSelection) error {
|
||||||
|
if sel == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
currency := paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
|
||||||
|
if _, err := payment.AmountToMinorUnit(payAmount, currency); err != nil {
|
||||||
|
return infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
|
||||||
|
WithMetadata(map[string]string{"currency": currency})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
|
func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
|
||||||
if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
|
if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
|
||||||
return false
|
return false
|
||||||
@ -596,6 +669,10 @@ func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest,
|
|||||||
PayURL: pr.PayURL,
|
PayURL: pr.PayURL,
|
||||||
QRCode: pr.QRCode,
|
QRCode: pr.QRCode,
|
||||||
ClientSecret: pr.ClientSecret,
|
ClientSecret: pr.ClientSecret,
|
||||||
|
IntentID: pr.IntentID,
|
||||||
|
Currency: pr.Currency,
|
||||||
|
CountryCode: pr.CountryCode,
|
||||||
|
PaymentEnv: pr.PaymentEnv,
|
||||||
OAuth: pr.OAuth,
|
OAuth: pr.OAuth,
|
||||||
JSAPI: pr.JSAPI,
|
JSAPI: pr.JSAPI,
|
||||||
JSAPIPayload: pr.JSAPI,
|
JSAPIPayload: pr.JSAPI,
|
||||||
|
|||||||
@ -188,6 +188,38 @@ func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey str
|
|||||||
return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
|
return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case payment.TypeStripe:
|
||||||
|
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
|
||||||
|
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
|
||||||
|
if actual == "" {
|
||||||
|
return fmt.Errorf("stripe notification missing currency")
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(expected, actual) {
|
||||||
|
return fmt.Errorf("stripe currency mismatch: expected %s, got %s", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case payment.TypeAirwallex:
|
||||||
|
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
|
||||||
|
actual := strings.TrimSpace(metadata["account_id"])
|
||||||
|
if actual == "" {
|
||||||
|
return fmt.Errorf("airwallex account_id missing")
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(expected, actual) {
|
||||||
|
return fmt.Errorf("airwallex account_id mismatch: expected %s, got %s", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
|
||||||
|
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
|
||||||
|
if actual == "" {
|
||||||
|
return fmt.Errorf("airwallex notification missing currency")
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(expected, actual) {
|
||||||
|
return fmt.Errorf("airwallex currency mismatch: expected %s, got %s", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if actual := strings.TrimSpace(metadata["status"]); actual != "" && !strings.EqualFold(actual, "SUCCEEDED") {
|
||||||
|
return fmt.Errorf("airwallex status mismatch: expected SUCCEEDED, got %s", actual)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -164,6 +164,30 @@ func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *te
|
|||||||
require.NotContains(t, snapshot, "pkey")
|
require.NotContains(t, snapshot, "pkey")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildPaymentOrderProviderSnapshot_IncludesProviderCurrency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
stripeSnapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
|
||||||
|
InstanceID: "77",
|
||||||
|
ProviderKey: payment.TypeStripe,
|
||||||
|
Config: map[string]string{
|
||||||
|
"currency": "hkd",
|
||||||
|
},
|
||||||
|
}, CreateOrderRequest{})
|
||||||
|
require.Equal(t, "HKD", stripeSnapshot["currency"])
|
||||||
|
|
||||||
|
airwallexSnapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
|
||||||
|
InstanceID: "78",
|
||||||
|
ProviderKey: payment.TypeAirwallex,
|
||||||
|
Config: map[string]string{
|
||||||
|
"currency": "usd",
|
||||||
|
"accountId": "acct-78",
|
||||||
|
},
|
||||||
|
}, CreateOrderRequest{})
|
||||||
|
require.Equal(t, "USD", airwallexSnapshot["currency"])
|
||||||
|
require.Equal(t, "acct-78", airwallexSnapshot["merchant_id"])
|
||||||
|
}
|
||||||
|
|
||||||
func valueOrEmpty(v *string) string {
|
func valueOrEmpty(v *string) string {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@ -91,6 +91,53 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateSelectedCreateOrderAmountCurrencyRejectsFractionalZeroDecimal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := validateSelectedCreateOrderAmountCurrency("100.50", &payment.InstanceSelection{
|
||||||
|
ProviderKey: payment.TypeStripe,
|
||||||
|
Config: map[string]string{"currency": "JPY"},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected fractional JPY amount to fail")
|
||||||
|
}
|
||||||
|
if appErr := infraerrors.FromError(err); appErr.Reason != "INVALID_AMOUNT" {
|
||||||
|
t.Fatalf("reason = %q, want INVALID_AMOUNT", appErr.Reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateCreateOrderPayAmountUsesCurrencyPrecision(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
amountStr, amount, err := calculateCreateOrderPayAmount(100, 2.5, "JPY")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if amountStr != "103" || amount != 103 {
|
||||||
|
t.Fatalf("JPY pay amount = (%q, %v), want (103, 103)", amountStr, amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
amountStr, amount, err = calculateCreateOrderPayAmount(12.345, 1, "KWD")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if amountStr != "12.469" || amount != 12.469 {
|
||||||
|
t.Fatalf("KWD pay amount = (%q, %v), want (12.469, 12.469)", amountStr, amount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateCreateOrderPayAmountRejectsFractionalZeroDecimal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, _, err := calculateCreateOrderPayAmount(100.5, 0, "JPY")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected fractional JPY amount to fail")
|
||||||
|
}
|
||||||
|
if appErr := infraerrors.FromError(err); appErr.Reason != "INVALID_AMOUNT" {
|
||||||
|
t.Fatalf("reason = %q, want INVALID_AMOUNT", appErr.Reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
|
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
|
||||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||||
|
|
||||||
|
|||||||
@ -226,10 +226,11 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
|
|||||||
if amt <= 0 {
|
if amt <= 0 {
|
||||||
amt = o.Amount
|
amt = o.Amount
|
||||||
}
|
}
|
||||||
if amt-o.Amount > amountToleranceCNY {
|
orderCurrency := PaymentOrderCurrency(o)
|
||||||
|
if amt-o.Amount > paymentAmountToleranceForCurrency(orderCurrency) {
|
||||||
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
|
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
|
||||||
}
|
}
|
||||||
ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt)
|
ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt, orderCurrency)
|
||||||
rr := strings.TrimSpace(reason)
|
rr := strings.TrimSpace(reason)
|
||||||
if rr == "" && o.RefundRequestReason != nil {
|
if rr == "" && o.RefundRequestReason != nil {
|
||||||
rr = *o.RefundRequestReason
|
rr = *o.RefundRequestReason
|
||||||
@ -339,13 +340,35 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
|
|||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = prov.Refund(ctx, payment.RefundRequest{
|
resp, err := prov.Refund(ctx, payment.RefundRequest{
|
||||||
TradeNo: p.Order.PaymentTradeNo,
|
TradeNo: p.Order.PaymentTradeNo,
|
||||||
OrderID: p.Order.OutTradeNo,
|
OrderID: p.Order.OutTradeNo,
|
||||||
Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
|
Amount: formatGatewayRefundAmount(p.GatewayAmount, p.Order),
|
||||||
Reason: p.Reason,
|
Reason: p.Reason,
|
||||||
})
|
})
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return validateRefundProviderResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatGatewayRefundAmount(amount float64, order *dbent.PaymentOrder) string {
|
||||||
|
return payment.FormatAmountForCurrency(amount, PaymentOrderCurrency(order))
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRefundProviderResponse(resp *payment.RefundResponse) error {
|
||||||
|
if resp == nil {
|
||||||
|
return fmt.Errorf("payment refund response missing")
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(resp.Status)
|
||||||
|
switch status {
|
||||||
|
case payment.ProviderStatusSuccess, payment.ProviderStatusRefunded, payment.ProviderStatusPending:
|
||||||
|
return nil
|
||||||
|
case payment.ProviderStatusFailed:
|
||||||
|
return fmt.Errorf("payment refund failed: status %s", status)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("payment refund returned unknown status: %s", status)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRefundProvider creates a provider using the order's original instance config.
|
// getRefundProvider creates a provider using the order's original instance config.
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -184,3 +185,26 @@ func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.ErrorContains(t, err, "alipay app_id mismatch")
|
require.ErrorContains(t, err, "alipay app_id mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCalculateGatewayRefundAmountUsesCurrencyPrecision(t *testing.T) {
|
||||||
|
require.InDelta(t, 6.173, calculateGatewayRefundAmount(100, 12.345, 50, "KWD"), 1e-12)
|
||||||
|
require.InDelta(t, 12.345, calculateGatewayRefundAmount(100, 12.345, 100, "KWD"), 1e-12)
|
||||||
|
require.InDelta(t, 52, calculateGatewayRefundAmount(100, 103, 50, "JPY"), 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatGatewayRefundAmountUsesOrderCurrency(t *testing.T) {
|
||||||
|
order := &dbent.PaymentOrder{
|
||||||
|
ProviderSnapshot: map[string]any{
|
||||||
|
"currency": "KWD",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, "12.345", formatGatewayRefundAmount(12.345, order))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefundProviderResponseAcceptsPending(t *testing.T) {
|
||||||
|
require.NoError(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusPending}))
|
||||||
|
require.NoError(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusSuccess}))
|
||||||
|
require.Error(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusFailed}))
|
||||||
|
require.Error(t, validateRefundProviderResponse(nil))
|
||||||
|
}
|
||||||
|
|||||||
@ -97,6 +97,10 @@ type CreateOrderResponse struct {
|
|||||||
PayURL string `json:"pay_url,omitempty"`
|
PayURL string `json:"pay_url,omitempty"`
|
||||||
QRCode string `json:"qr_code,omitempty"`
|
QRCode string `json:"qr_code,omitempty"`
|
||||||
ClientSecret string `json:"client_secret,omitempty"`
|
ClientSecret string `json:"client_secret,omitempty"`
|
||||||
|
IntentID string `json:"intent_id,omitempty"`
|
||||||
|
Currency string `json:"currency,omitempty"`
|
||||||
|
CountryCode string `json:"country_code,omitempty"`
|
||||||
|
PaymentEnv string `json:"payment_env,omitempty"`
|
||||||
OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
|
OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
|
||||||
JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
|
JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
|
||||||
JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
|
JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
|
||||||
|
|||||||
@ -824,6 +824,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
|||||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||||
if account.Platform == PlatformOpenAI {
|
if account.Platform == PlatformOpenAI {
|
||||||
|
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
|
||||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||||
@ -1198,6 +1199,55 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseOpenAIRateLimitPlanType(body []byte) string {
|
||||||
|
var parsed map[string]any
|
||||||
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
errObj, ok := parsed["error"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
errType, _ := errObj["type"].(string)
|
||||||
|
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
planType, _ := errObj["plan_type"].(string)
|
||||||
|
return strings.ToLower(strings.TrimSpace(planType))
|
||||||
|
}
|
||||||
|
|
||||||
|
func persistOpenAI429PlanType(ctx context.Context, repo AccountRepository, account *Account, body []byte) {
|
||||||
|
if repo == nil || account == nil || account.Platform != PlatformOpenAI {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
planType := parseOpenAIRateLimitPlanType(body)
|
||||||
|
if planType == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
current := strings.TrimSpace(account.GetCredential("plan_type"))
|
||||||
|
if strings.EqualFold(current, planType) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := repo.BulkUpdate(ctx, []int64{account.ID}, AccountBulkUpdate{
|
||||||
|
Credentials: map[string]any{"plan_type": planType},
|
||||||
|
}); err != nil {
|
||||||
|
slog.Warn("openai_429_plan_type_sync_failed", "account_id", account.ID, "plan_type", planType, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Credentials == nil {
|
||||||
|
account.Credentials = make(map[string]any, 1)
|
||||||
|
}
|
||||||
|
account.Credentials["plan_type"] = planType
|
||||||
|
slog.Info("openai_429_plan_type_synced", "account_id", account.ID, "previous_plan_type", current, "plan_type", planType)
|
||||||
|
}
|
||||||
|
|
||||||
// handle529 处理529过载错误
|
// handle529 处理529过载错误
|
||||||
// 根据配置决定是否暂停账号调度及冷却时长
|
// 根据配置决定是否暂停账号调度及冷却时长
|
||||||
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||||
|
|||||||
@ -149,8 +149,10 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
|||||||
|
|
||||||
type openAI429SnapshotRepo struct {
|
type openAI429SnapshotRepo struct {
|
||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
rateLimitedID int64
|
rateLimitedID int64
|
||||||
updatedExtra map[string]any
|
updatedExtra map[string]any
|
||||||
|
bulkUpdatedIDs []int64
|
||||||
|
bulkUpdatedPayload AccountBulkUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||||
@ -163,6 +165,12 @@ func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *openAI429SnapshotRepo) BulkUpdate(_ context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||||
|
r.bulkUpdatedIDs = append([]int64(nil), ids...)
|
||||||
|
r.bulkUpdatedPayload = updates
|
||||||
|
return int64(len(ids)), nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) {
|
func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) {
|
||||||
repo := &openAI429SnapshotRepo{}
|
repo := &openAI429SnapshotRepo{}
|
||||||
svc := NewRateLimitService(repo, nil, nil, nil, nil)
|
svc := NewRateLimitService(repo, nil, nil, nil, nil)
|
||||||
@ -192,6 +200,25 @@ func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandle429_OpenAISyncsObservedPlanType(t *testing.T) {
|
||||||
|
repo := &openAI429SnapshotRepo{}
|
||||||
|
svc := NewRateLimitService(repo, nil, nil, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 124,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"plan_type": "plus"},
|
||||||
|
}
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","message":"limit reached","plan_type":"free","resets_at":1777283883}}`)
|
||||||
|
|
||||||
|
svc.handle429(context.Background(), account, http.Header{}, body)
|
||||||
|
|
||||||
|
require.Equal(t, []int64{account.ID}, repo.bulkUpdatedIDs)
|
||||||
|
require.Equal(t, "free", repo.bulkUpdatedPayload.Credentials["plan_type"])
|
||||||
|
require.Equal(t, "free", account.Credentials["plan_type"])
|
||||||
|
require.Equal(t, account.ID, repo.rateLimitedID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizedCodexLimits(t *testing.T) {
|
func TestNormalizedCodexLimits(t *testing.T) {
|
||||||
// Test the Normalize() method directly
|
// Test the Normalize() method directly
|
||||||
pUsed := 100.0
|
pUsed := 100.0
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
@ -87,6 +88,7 @@ type cachedGatewayForwardingSettings struct {
|
|||||||
metadataPassthrough bool
|
metadataPassthrough bool
|
||||||
cchSigning bool
|
cchSigning bool
|
||||||
anthropicCacheTTL1hInjection bool
|
anthropicCacheTTL1hInjection bool
|
||||||
|
rewriteMessageCacheControl bool
|
||||||
expiresAt int64 // unix nano
|
expiresAt int64 // unix nano
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,6 +99,16 @@ const gatewayForwardingCacheTTL = 60 * time.Second
|
|||||||
const gatewayForwardingErrorTTL = 5 * time.Second
|
const gatewayForwardingErrorTTL = 5 * time.Second
|
||||||
const gatewayForwardingDBTimeout = 5 * time.Second
|
const gatewayForwardingDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// cachedAntigravityUserAgentVersion 缓存 Antigravity UA 版本号(进程内缓存,60s TTL)
|
||||||
|
type cachedAntigravityUserAgentVersion struct {
|
||||||
|
version string
|
||||||
|
expiresAt int64 // unix nano
|
||||||
|
}
|
||||||
|
|
||||||
|
const antigravityUserAgentVersionCacheTTL = 60 * time.Second
|
||||||
|
const antigravityUserAgentVersionErrorTTL = 5 * time.Second
|
||||||
|
const antigravityUserAgentVersionDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||||
type DefaultSubscriptionGroupReader interface {
|
type DefaultSubscriptionGroupReader interface {
|
||||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||||
@ -108,13 +120,15 @@ type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[i
|
|||||||
|
|
||||||
// SettingService 系统设置服务
|
// SettingService 系统设置服务
|
||||||
type SettingService struct {
|
type SettingService struct {
|
||||||
settingRepo SettingRepository
|
settingRepo SettingRepository
|
||||||
defaultSubGroupReader DefaultSubscriptionGroupReader
|
defaultSubGroupReader DefaultSubscriptionGroupReader
|
||||||
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
|
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||||
version string // Application version
|
version string // Application version
|
||||||
webSearchManagerBuilder WebSearchManagerBuilder
|
webSearchManagerBuilder WebSearchManagerBuilder
|
||||||
|
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
|
||||||
|
antigravityUAVersionSF singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderDefaultGrantSettings struct {
|
type ProviderDefaultGrantSettings struct {
|
||||||
@ -809,6 +823,55 @@ func (s *SettingService) GetAvailableChannelsRuntime(ctx context.Context) Availa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAntigravityUserAgentVersion 返回 Antigravity 上游请求使用的版本号。
|
||||||
|
// 后台设置优先;为空、缺失或非法时回退到 ANTIGRAVITY_USER_AGENT_VERSION / 内置默认值。
|
||||||
|
func (s *SettingService) GetAntigravityUserAgentVersion(ctx context.Context) string {
|
||||||
|
fallback := antigravity.GetDefaultUserAgentVersion()
|
||||||
|
if s == nil || s.settingRepo == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if cached, ok := s.antigravityUAVersionCache.Load().(*cachedAntigravityUserAgentVersion); ok && cached != nil {
|
||||||
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
|
return cached.version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _, _ := s.antigravityUAVersionSF.Do("antigravity_user_agent_version", func() (any, error) {
|
||||||
|
if cached, ok := s.antigravityUAVersionCache.Load().(*cachedAntigravityUserAgentVersion); ok && cached != nil {
|
||||||
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
|
return cached.version, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), antigravityUserAgentVersionDBTimeout)
|
||||||
|
defer cancel()
|
||||||
|
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyAntigravityUserAgentVersion)
|
||||||
|
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||||
|
slog.Warn("failed to get antigravity user agent version setting", "error", err)
|
||||||
|
s.antigravityUAVersionCache.Store(&cachedAntigravityUserAgentVersion{
|
||||||
|
version: fallback,
|
||||||
|
expiresAt: time.Now().Add(antigravityUserAgentVersionErrorTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
version := antigravity.NormalizeUserAgentVersion(value)
|
||||||
|
if version == "" {
|
||||||
|
version = fallback
|
||||||
|
}
|
||||||
|
s.antigravityUAVersionCache.Store(&cachedAntigravityUserAgentVersion{
|
||||||
|
version: version,
|
||||||
|
expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return version, nil
|
||||||
|
})
|
||||||
|
if version, ok := result.(string); ok && version != "" {
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
||||||
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
||||||
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||||
@ -1584,6 +1647,8 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
|
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
|
||||||
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
|
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
|
||||||
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
||||||
|
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
|
||||||
|
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
||||||
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
||||||
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
||||||
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
||||||
@ -1652,8 +1717,18 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
|||||||
metadataPassthrough: settings.EnableMetadataPassthrough,
|
metadataPassthrough: settings.EnableMetadataPassthrough,
|
||||||
cchSigning: settings.EnableCCHSigning,
|
cchSigning: settings.EnableCCHSigning,
|
||||||
anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||||
|
rewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
||||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||||
})
|
})
|
||||||
|
s.antigravityUAVersionSF.Forget("antigravity_user_agent_version")
|
||||||
|
antigravityUserAgentVersion := antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
||||||
|
if antigravityUserAgentVersion == "" {
|
||||||
|
antigravityUserAgentVersion = antigravity.GetDefaultUserAgentVersion()
|
||||||
|
}
|
||||||
|
s.antigravityUAVersionCache.Store(&cachedAntigravityUserAgentVersion{
|
||||||
|
version: antigravityUserAgentVersion,
|
||||||
|
expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
||||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||||
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
||||||
@ -1664,6 +1739,10 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SettingService) defaultRewriteMessageCacheControl() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
|
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
|
||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -1815,17 +1894,18 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type gatewayForwardingSettingsResult struct {
|
type gatewayForwardingSettingsResult struct {
|
||||||
fp, mp, cch, cacheTTL1h bool
|
fp, mp, cch, cacheTTL1h, rewriteMessageCacheControl bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
|
func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
|
||||||
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
||||||
if time.Now().UnixNano() < cached.expiresAt {
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
return gatewayForwardingSettingsResult{
|
return gatewayForwardingSettingsResult{
|
||||||
fp: cached.fingerprintUnification,
|
fp: cached.fingerprintUnification,
|
||||||
mp: cached.metadataPassthrough,
|
mp: cached.metadataPassthrough,
|
||||||
cch: cached.cchSigning,
|
cch: cached.cchSigning,
|
||||||
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
||||||
|
rewriteMessageCacheControl: cached.rewriteMessageCacheControl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1833,10 +1913,11 @@ func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context)
|
|||||||
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
|
||||||
if time.Now().UnixNano() < cached.expiresAt {
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
return gatewayForwardingSettingsResult{
|
return gatewayForwardingSettingsResult{
|
||||||
fp: cached.fingerprintUnification,
|
fp: cached.fingerprintUnification,
|
||||||
mp: cached.metadataPassthrough,
|
mp: cached.metadataPassthrough,
|
||||||
cch: cached.cchSigning,
|
cch: cached.cchSigning,
|
||||||
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
cacheTTL1h: cached.anthropicCacheTTL1hInjection,
|
||||||
|
rewriteMessageCacheControl: cached.rewriteMessageCacheControl,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1847,6 +1928,7 @@ func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context)
|
|||||||
SettingKeyEnableMetadataPassthrough,
|
SettingKeyEnableMetadataPassthrough,
|
||||||
SettingKeyEnableCCHSigning,
|
SettingKeyEnableCCHSigning,
|
||||||
SettingKeyEnableAnthropicCacheTTL1hInjection,
|
SettingKeyEnableAnthropicCacheTTL1hInjection,
|
||||||
|
SettingKeyRewriteMessageCacheControl,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to get gateway forwarding settings", "error", err)
|
slog.Warn("failed to get gateway forwarding settings", "error", err)
|
||||||
@ -1855,9 +1937,10 @@ func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context)
|
|||||||
metadataPassthrough: false,
|
metadataPassthrough: false,
|
||||||
cchSigning: false,
|
cchSigning: false,
|
||||||
anthropicCacheTTL1hInjection: false,
|
anthropicCacheTTL1hInjection: false,
|
||||||
|
rewriteMessageCacheControl: s.defaultRewriteMessageCacheControl(),
|
||||||
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
|
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
|
||||||
})
|
})
|
||||||
return gatewayForwardingSettingsResult{fp: true}, nil
|
return gatewayForwardingSettingsResult{fp: true, rewriteMessageCacheControl: s.defaultRewriteMessageCacheControl()}, nil
|
||||||
}
|
}
|
||||||
fp := true
|
fp := true
|
||||||
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
|
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
|
||||||
@ -1866,14 +1949,25 @@ func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context)
|
|||||||
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
|
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
|
||||||
cch := values[SettingKeyEnableCCHSigning] == "true"
|
cch := values[SettingKeyEnableCCHSigning] == "true"
|
||||||
cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
||||||
|
rewriteMessageCacheControl := s.defaultRewriteMessageCacheControl()
|
||||||
|
if v, ok := values[SettingKeyRewriteMessageCacheControl]; ok && v != "" {
|
||||||
|
rewriteMessageCacheControl = v == "true"
|
||||||
|
}
|
||||||
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
|
||||||
fingerprintUnification: fp,
|
fingerprintUnification: fp,
|
||||||
metadataPassthrough: mp,
|
metadataPassthrough: mp,
|
||||||
cchSigning: cch,
|
cchSigning: cch,
|
||||||
anthropicCacheTTL1hInjection: cacheTTL1h,
|
anthropicCacheTTL1hInjection: cacheTTL1h,
|
||||||
|
rewriteMessageCacheControl: rewriteMessageCacheControl,
|
||||||
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
|
||||||
})
|
})
|
||||||
return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil
|
return gatewayForwardingSettingsResult{
|
||||||
|
fp: fp,
|
||||||
|
mp: mp,
|
||||||
|
cch: cch,
|
||||||
|
cacheTTL1h: cacheTTL1h,
|
||||||
|
rewriteMessageCacheControl: rewriteMessageCacheControl,
|
||||||
|
}, nil
|
||||||
})
|
})
|
||||||
if r, ok := val.(gatewayForwardingSettingsResult); ok {
|
if r, ok := val.(gatewayForwardingSettingsResult); ok {
|
||||||
return r
|
return r
|
||||||
@ -1894,6 +1988,11 @@ func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Conte
|
|||||||
return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h
|
return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRewriteMessageCacheControlEnabled 检查是否启用 messages cache_control 改写。
|
||||||
|
func (s *SettingService) IsRewriteMessageCacheControlEnabled(ctx context.Context) bool {
|
||||||
|
return s.getGatewayForwardingSettingsCached(ctx).rewriteMessageCacheControl
|
||||||
|
}
|
||||||
|
|
||||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||||
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
|
||||||
@ -2358,6 +2457,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
// 分组隔离(默认不允许未分组 Key 调度)
|
// 分组隔离(默认不允许未分组 Key 调度)
|
||||||
SettingKeyAllowUngroupedKeyScheduling: "false",
|
SettingKeyAllowUngroupedKeyScheduling: "false",
|
||||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
||||||
|
SettingKeyRewriteMessageCacheControl: strconv.FormatBool(s.defaultRewriteMessageCacheControl()),
|
||||||
|
SettingKeyAntigravityUserAgentVersion: "",
|
||||||
SettingPaymentVisibleMethodAlipaySource: "",
|
SettingPaymentVisibleMethodAlipaySource: "",
|
||||||
SettingPaymentVisibleMethodWxpaySource: "",
|
SettingPaymentVisibleMethodWxpaySource: "",
|
||||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||||
@ -2734,6 +2835,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
||||||
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
||||||
result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
|
||||||
|
if v, ok := settings[SettingKeyRewriteMessageCacheControl]; ok && v != "" {
|
||||||
|
result.RewriteMessageCacheControl = v == "true"
|
||||||
|
} else {
|
||||||
|
result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl()
|
||||||
|
}
|
||||||
|
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
|
||||||
|
|
||||||
// Web search emulation: quick enabled check from the JSON config
|
// Web search emulation: quick enabled check from the JSON config
|
||||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -48,6 +49,41 @@ func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error {
|
|||||||
panic("unexpected Delete call")
|
panic("unexpected Delete call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type settingAntigravityUARepoStub struct {
|
||||||
|
values map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
if value, ok := s.values[key]; ok {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
return "", ErrSettingNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
panic("unexpected GetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
panic("unexpected SetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
panic("unexpected GetAll call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *settingAntigravityUARepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
type defaultSubGroupReaderStub struct {
|
type defaultSubGroupReaderStub struct {
|
||||||
byID map[int64]*Group
|
byID map[int64]*Group
|
||||||
errBy map[int64]error
|
errBy map[int64]error
|
||||||
@ -243,6 +279,41 @@ func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler
|
|||||||
require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey])
|
require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSettingService_UpdateSettings_AntigravityUserAgentVersion(t *testing.T) {
|
||||||
|
repo := &settingUpdateRepoStub{}
|
||||||
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|
||||||
|
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||||
|
AntigravityUserAgentVersion: "1.23.2",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "1.23.2", repo.updates[SettingKeyAntigravityUserAgentVersion])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingService_GetAntigravityUserAgentVersion_Precedence(t *testing.T) {
|
||||||
|
t.Run("后台设置优先", func(t *testing.T) {
|
||||||
|
svc := NewSettingService(&settingAntigravityUARepoStub{values: map[string]string{
|
||||||
|
SettingKeyAntigravityUserAgentVersion: "1.24.0",
|
||||||
|
}}, &config.Config{})
|
||||||
|
|
||||||
|
require.Equal(t, "1.24.0", svc.GetAntigravityUserAgentVersion(context.Background()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("空值回退配置默认值", func(t *testing.T) {
|
||||||
|
svc := NewSettingService(&settingAntigravityUARepoStub{values: map[string]string{
|
||||||
|
SettingKeyAntigravityUserAgentVersion: "",
|
||||||
|
}}, &config.Config{})
|
||||||
|
|
||||||
|
require.Equal(t, antigravity.GetDefaultUserAgentVersion(), svc.GetAntigravityUserAgentVersion(context.Background()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("缺失回退配置默认值", func(t *testing.T) {
|
||||||
|
svc := NewSettingService(&settingAntigravityUARepoStub{values: map[string]string{}}, &config.Config{})
|
||||||
|
|
||||||
|
require.Equal(t, antigravity.GetDefaultUserAgentVersion(), svc.GetAntigravityUserAgentVersion(context.Background()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
|
func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
|
||||||
repo := &settingUpdateRepoStub{}
|
repo := &settingUpdateRepoStub{}
|
||||||
svc := NewSettingService(repo, &config.Config{})
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|||||||
@ -168,10 +168,12 @@ type SystemSettings struct {
|
|||||||
BackendModeEnabled bool
|
BackendModeEnabled bool
|
||||||
|
|
||||||
// Gateway forwarding behavior
|
// Gateway forwarding behavior
|
||||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||||
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||||
|
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false)
|
||||||
|
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
|
||||||
|
|
||||||
// Web Search Emulation
|
// Web Search Emulation
|
||||||
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
||||||
|
|||||||
@ -16,6 +16,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,7 +175,7 @@ func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key, vertexServiceAccountProxyURL(account))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -183,7 +185,36 @@ func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCa
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
|
func vertexServiceAccountProxyURL(account *Account) string {
|
||||||
|
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVertexServiceAccountHTTPClient(proxyURL string) (*http.Client, error) {
|
||||||
|
proxyURL = strings.TrimSpace(proxyURL)
|
||||||
|
if proxyURL == "" {
|
||||||
|
return &http.Client{Timeout: 15 * time.Second}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, parsedProxy, err := proxyurl.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected default transport type %T", http.DefaultTransport)
|
||||||
|
}
|
||||||
|
transport := defaultTransport.Clone()
|
||||||
|
transport.Proxy = nil
|
||||||
|
if err := proxyutil.ConfigureTransportProxy(transport, parsedProxy); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &http.Client{Timeout: 15 * time.Second, Transport: transport}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey, proxyURL string) (string, time.Duration, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"iss": key.ClientEmail,
|
"iss": key.ClientEmail,
|
||||||
@ -215,7 +246,10 @@ func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAc
|
|||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 15 * time.Second}
|
client, err := newVertexServiceAccountHTTPClient(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, fmt.Errorf("configure service account token proxy: %w", err)
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, fmt.Errorf("service account token request failed: %w", err)
|
return "", 0, fmt.Errorf("service account token request failed: %w", err)
|
||||||
|
|||||||
@ -1,8 +1,17 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@ -75,3 +84,70 @@ func TestParseVertexServiceAccountKey(t *testing.T) {
|
|||||||
require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
|
require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
|
||||||
require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
|
require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVertexServiceAccountProxyURL(t *testing.T) {
|
||||||
|
proxyID := int64(7)
|
||||||
|
account := &Account{
|
||||||
|
ProxyID: &proxyID,
|
||||||
|
Proxy: &Proxy{
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "proxy.example.com",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, "http://proxy.example.com:8080", vertexServiceAccountProxyURL(account))
|
||||||
|
require.Empty(t, vertexServiceAccountProxyURL(&Account{Proxy: account.Proxy}))
|
||||||
|
require.Empty(t, vertexServiceAccountProxyURL(&Account{ProxyID: &proxyID}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeVertexServiceAccountTokenUsesProxy(t *testing.T) {
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pemBytes := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
seenProxyRequest := make(chan string, 1)
|
||||||
|
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
seenProxyRequest <- r.URL.String()
|
||||||
|
require.Equal(t, "oauth2.googleapis.com", r.URL.Host)
|
||||||
|
require.Equal(t, "/token", r.URL.Path)
|
||||||
|
require.NoError(t, r.ParseForm())
|
||||||
|
require.Equal(t, "urn:ietf:params:oauth:grant-type:jwt-bearer", r.PostForm.Get("grant_type"))
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
require.NoError(t, json.NewEncoder(w).Encode(vertexTokenResponse{
|
||||||
|
AccessToken: "proxied-token",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
key := &vertexServiceAccountKey{
|
||||||
|
ClientEmail: "svc@example.iam.gserviceaccount.com",
|
||||||
|
PrivateKey: string(pemBytes),
|
||||||
|
TokenURI: "http://oauth2.googleapis.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
token, ttl, err := exchangeVertexServiceAccountToken(contextWithTestTimeout(t), key, proxy.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "proxied-token", token)
|
||||||
|
require.Equal(t, 55*time.Minute, ttl)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-seenProxyRequest:
|
||||||
|
require.Equal(t, "http://oauth2.googleapis.com/token", got)
|
||||||
|
default:
|
||||||
|
t.Fatal("token request did not reach proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextWithTestTimeout(t *testing.T) context.Context {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
@ -398,6 +399,7 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
|
|||||||
svc := NewSettingService(settingRepo, cfg)
|
svc := NewSettingService(settingRepo, cfg)
|
||||||
svc.SetDefaultSubscriptionGroupReader(groupRepo)
|
svc.SetDefaultSubscriptionGroupReader(groupRepo)
|
||||||
svc.SetProxyRepository(proxyRepo)
|
svc.SetProxyRepository(proxyRepo)
|
||||||
|
antigravity.SetUserAgentVersionResolver(svc.GetAntigravityUserAgentVersion)
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -231,6 +231,9 @@ TOTP_ENCRYPTION_KEY=
|
|||||||
#
|
#
|
||||||
# Antigravity OAuth client_secret(用于 Antigravity OAuth 登录流)
|
# Antigravity OAuth client_secret(用于 Antigravity OAuth 登录流)
|
||||||
# ANTIGRAVITY_OAUTH_CLIENT_SECRET=
|
# ANTIGRAVITY_OAUTH_CLIENT_SECRET=
|
||||||
|
#
|
||||||
|
# Antigravity User-Agent 版本号(后台设置 antigravity_user_agent_version 优先;留空使用内置默认 1.23.2)
|
||||||
|
# ANTIGRAVITY_USER_AGENT_VERSION=
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Rate Limiting (Optional)
|
# Rate Limiting (Optional)
|
||||||
|
|||||||
@ -129,7 +129,7 @@ security:
|
|||||||
# 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖)
|
# 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖)
|
||||||
# Note: __CSP_NONCE__ will be replaced with 'nonce-xxx' at request time for inline script security
|
# Note: __CSP_NONCE__ will be replaced with 'nonce-xxx' at request time for inline script security
|
||||||
# 注意:__CSP_NONCE__ 会在请求时被替换为 'nonce-xxx',用于内联脚本安全
|
# 注意:__CSP_NONCE__ 会在请求时被替换为 'nonce-xxx',用于内联脚本安全
|
||||||
policy: "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
policy: "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com https://static.airwallex.com https://checkout.airwallex.com https://static-demo.airwallex.com https://checkout-demo.airwallex.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://static.airwallex.com https://checkout.airwallex.com https://static-demo.airwallex.com https://checkout-demo.airwallex.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com https://checkout.airwallex.com https://checkout-demo.airwallex.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||||
proxy_probe:
|
proxy_probe:
|
||||||
# Allow skipping TLS verification for proxy probe (debug only)
|
# Allow skipping TLS verification for proxy probe (debug only)
|
||||||
# 允许代理探测时跳过 TLS 证书验证(仅用于调试)
|
# 允许代理探测时跳过 TLS 证书验证(仅用于调试)
|
||||||
|
|||||||
@ -128,6 +128,7 @@ services:
|
|||||||
# SECURITY: This repo does not embed third-party client_secret.
|
# SECURITY: This repo does not embed third-party client_secret.
|
||||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||||
|
- ANTIGRAVITY_USER_AGENT_VERSION=${ANTIGRAVITY_USER_AGENT_VERSION:-}
|
||||||
|
|
||||||
# =======================================================================
|
# =======================================================================
|
||||||
# Security Configuration (URL Allowlist)
|
# Security Configuration (URL Allowlist)
|
||||||
|
|||||||
@ -93,6 +93,7 @@ services:
|
|||||||
# SECURITY: This repo does not embed third-party client_secret.
|
# SECURITY: This repo does not embed third-party client_secret.
|
||||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||||
|
- ANTIGRAVITY_USER_AGENT_VERSION=${ANTIGRAVITY_USER_AGENT_VERSION:-}
|
||||||
|
|
||||||
# =======================================================================
|
# =======================================================================
|
||||||
# Image Generation Stream & Concurrency
|
# Image Generation Stream & Concurrency
|
||||||
|
|||||||
@ -90,6 +90,7 @@ services:
|
|||||||
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
|
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
|
||||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||||
|
- ANTIGRAVITY_USER_AGENT_VERSION=${ANTIGRAVITY_USER_AGENT_VERSION:-}
|
||||||
|
|
||||||
# --- Security ---
|
# --- Security ---
|
||||||
- SECURITY_URL_ALLOWLIST_ENABLED=${SECURITY_URL_ALLOWLIST_ENABLED:-false}
|
- SECURITY_URL_ALLOWLIST_ENABLED=${SECURITY_URL_ALLOWLIST_ENABLED:-false}
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
"test:coverage": "vitest run --coverage"
|
"test:coverage": "vitest run --coverage"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@airwallex/components-sdk": "^1.30.2",
|
||||||
"@lobehub/icons": "^4.0.2",
|
"@lobehub/icons": "^4.0.2",
|
||||||
"@tanstack/vue-virtual": "^3.13.23",
|
"@tanstack/vue-virtual": "^3.13.23",
|
||||||
"@vueuse/core": "^10.7.0",
|
"@vueuse/core": "^10.7.0",
|
||||||
|
|||||||
15
frontend/pnpm-lock.yaml
generated
15
frontend/pnpm-lock.yaml
generated
@ -8,6 +8,9 @@ importers:
|
|||||||
|
|
||||||
.:
|
.:
|
||||||
dependencies:
|
dependencies:
|
||||||
|
'@airwallex/components-sdk':
|
||||||
|
specifier: ^1.30.2
|
||||||
|
version: 1.30.2
|
||||||
'@lobehub/icons':
|
'@lobehub/icons':
|
||||||
specifier: ^4.0.2
|
specifier: ^4.0.2
|
||||||
version: 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
|
version: 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
|
||||||
@ -129,6 +132,12 @@ importers:
|
|||||||
|
|
||||||
packages:
|
packages:
|
||||||
|
|
||||||
|
'@airwallex/airtracker@3.2.0':
|
||||||
|
resolution: {integrity: sha512-PKE5N38ajTVg6ph9JzLpWsICNjqLtf/wWudNVU3UPX9SVy2I5s5ITc281sMSD8+LIE6RJoGjGTO+VYP/io5kig==}
|
||||||
|
|
||||||
|
'@airwallex/components-sdk@1.30.2':
|
||||||
|
resolution: {integrity: sha512-BGwAPCACwOJm8XNxDxJGMq1o/73D9+ZWifvp5YHvfgIwxg1RGVCIME0tP1g8cash3fVLHgl7xObyS1QbIOSDXw==}
|
||||||
|
|
||||||
'@alloc/quick-lru@5.2.0':
|
'@alloc/quick-lru@5.2.0':
|
||||||
resolution: {integrity: sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==}
|
resolution: {integrity: sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==}
|
||||||
engines: {node: '>=10'}
|
engines: {node: '>=10'}
|
||||||
@ -4516,6 +4525,12 @@ packages:
|
|||||||
|
|
||||||
snapshots:
|
snapshots:
|
||||||
|
|
||||||
|
'@airwallex/airtracker@3.2.0': {}
|
||||||
|
|
||||||
|
'@airwallex/components-sdk@1.30.2':
|
||||||
|
dependencies:
|
||||||
|
'@airwallex/airtracker': 3.2.0
|
||||||
|
|
||||||
'@alloc/quick-lru@5.2.0': {}
|
'@alloc/quick-lru@5.2.0': {}
|
||||||
|
|
||||||
'@ampproject/remapping@2.3.0':
|
'@ampproject/remapping@2.3.0':
|
||||||
|
|||||||
@ -478,6 +478,8 @@ export interface SystemSettings {
|
|||||||
enable_metadata_passthrough: boolean;
|
enable_metadata_passthrough: boolean;
|
||||||
enable_cch_signing: boolean;
|
enable_cch_signing: boolean;
|
||||||
enable_anthropic_cache_ttl_1h_injection: boolean;
|
enable_anthropic_cache_ttl_1h_injection: boolean;
|
||||||
|
rewrite_message_cache_control: boolean;
|
||||||
|
antigravity_user_agent_version: string;
|
||||||
web_search_emulation_enabled?: boolean;
|
web_search_emulation_enabled?: boolean;
|
||||||
|
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
@ -675,6 +677,8 @@ export interface UpdateSettingsRequest {
|
|||||||
enable_metadata_passthrough?: boolean;
|
enable_metadata_passthrough?: boolean;
|
||||||
enable_cch_signing?: boolean;
|
enable_cch_signing?: boolean;
|
||||||
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
||||||
|
rewrite_message_cache_control?: boolean;
|
||||||
|
antigravity_user_agent_version?: string;
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
payment_enabled?: boolean;
|
payment_enabled?: boolean;
|
||||||
risk_control_enabled?: boolean;
|
risk_control_enabled?: boolean;
|
||||||
|
|||||||
13
frontend/src/assets/icons/airwallex.svg
Normal file
13
frontend/src/assets/icons/airwallex.svg
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48.1 32.3" role="img" aria-label="Airwallex">
|
||||||
|
<defs>
|
||||||
|
<linearGradient id="airwallex-mark" x1="0" y1="2" x2="48" y2="30" gradientUnits="userSpaceOnUse">
|
||||||
|
<stop offset="0" stop-color="#FF4F42"/>
|
||||||
|
<stop offset="1" stop-color="#FF8E3C"/>
|
||||||
|
</linearGradient>
|
||||||
|
</defs>
|
||||||
|
<path
|
||||||
|
fill="url(#airwallex-mark)"
|
||||||
|
d="M312.76 380.53a6 6 0 0 1 1.42 6.42l-3.18 8.58a6.89 6.89 0 0 1-5 4.47 6.8 6.8 0 0 1-1.3.13 6.58 6.58 0 0 1-5.08-2.4l-19-22.69a.42.42 0 0 0-.71.12l-6.17 16.67a.42.42 0 0 0 .55.54l7.57-3.09a3.34 3.34 0 0 1 4.44 2.08 3.47 3.47 0 0 1-2 4.24l-9.89 4a5.93 5.93 0 0 1-7.88-7.56l7.29-19.68a6.84 6.84 0 0 1 11.68-2l10.88 13 10-4.08a5.84 5.84 0 0 1 6.38 1.24ZM307 387.07a.42.42 0 0 0-.55-.54l-5.53 2.26 3.32 4a.42.42 0 0 0 .71-.13Z"
|
||||||
|
transform="translate(-266.13 -367.85)"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 858 B |
@ -130,13 +130,16 @@ const props = defineProps<{
|
|||||||
/** 充值金额 (base amount before fee) = pay_amount - fee = pay_amount / (1 + fee_rate/100) */
|
/** 充值金额 (base amount before fee) = pay_amount - fee = pay_amount / (1 + fee_rate/100) */
|
||||||
const baseAmount = computed(() => {
|
const baseAmount = computed(() => {
|
||||||
if (!props.order) return 0
|
if (!props.order) return 0
|
||||||
if (props.order.fee_rate <= 0) return props.order.pay_amount
|
const feeRate = Number(props.order.fee_rate) || 0
|
||||||
return props.order.pay_amount / (1 + props.order.fee_rate / 100)
|
if (feeRate <= 0) return props.order.pay_amount
|
||||||
|
return props.order.pay_amount / (1 + feeRate / 100)
|
||||||
})
|
})
|
||||||
|
|
||||||
/** 手续费 = pay_amount - baseAmount */
|
/** 手续费 = pay_amount - baseAmount */
|
||||||
const feeAmount = computed(() => {
|
const feeAmount = computed(() => {
|
||||||
if (!props.order || props.order.fee_rate <= 0) return 0
|
if (!props.order) return 0
|
||||||
|
const feeRate = Number(props.order.fee_rate) || 0
|
||||||
|
if (feeRate <= 0) return 0
|
||||||
return props.order.pay_amount - baseAmount.value
|
return props.order.pay_amount - baseAmount.value
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -212,6 +212,7 @@ const paymentTypeFilterOptions = computed(() => [
|
|||||||
{ value: 'alipay', label: t('payment.methods.alipay') },
|
{ value: 'alipay', label: t('payment.methods.alipay') },
|
||||||
{ value: 'wxpay', label: t('payment.methods.wxpay') },
|
{ value: 'wxpay', label: t('payment.methods.wxpay') },
|
||||||
{ value: 'stripe', label: t('payment.methods.stripe') },
|
{ value: 'stripe', label: t('payment.methods.stripe') },
|
||||||
|
{ value: 'airwallex', label: t('payment.methods.airwallex') },
|
||||||
])
|
])
|
||||||
|
|
||||||
const orderTypeFilterOptions = computed(() => [
|
const orderTypeFilterOptions = computed(() => [
|
||||||
|
|||||||
@ -20,7 +20,7 @@
|
|||||||
@click="method.available && emit('select', method.type)"
|
@click="method.available && emit('select', method.type)"
|
||||||
>
|
>
|
||||||
<span class="flex items-center gap-2">
|
<span class="flex items-center gap-2">
|
||||||
<img :src="methodIcon(method.type)" :alt="t(`payment.methods.${method.type}`)" class="h-7 w-7" />
|
<img :src="methodIcon(method.type)" :alt="t(`payment.methods.${method.type}`)" class="h-7 w-7 object-contain" />
|
||||||
<span class="flex flex-col items-start leading-none">
|
<span class="flex flex-col items-start leading-none">
|
||||||
<span class="text-base font-semibold">{{ t(`payment.methods.${method.type}`) }}</span>
|
<span class="text-base font-semibold">{{ t(`payment.methods.${method.type}`) }}</span>
|
||||||
<span
|
<span
|
||||||
@ -43,6 +43,7 @@ import { METHOD_ORDER } from './providerConfig'
|
|||||||
import alipayIcon from '@/assets/icons/alipay.svg'
|
import alipayIcon from '@/assets/icons/alipay.svg'
|
||||||
import wxpayIcon from '@/assets/icons/wxpay.svg'
|
import wxpayIcon from '@/assets/icons/wxpay.svg'
|
||||||
import stripeIcon from '@/assets/icons/stripe.svg'
|
import stripeIcon from '@/assets/icons/stripe.svg'
|
||||||
|
import airwallexIcon from '@/assets/icons/airwallex.svg'
|
||||||
|
|
||||||
export interface PaymentMethodOption {
|
export interface PaymentMethodOption {
|
||||||
type: string
|
type: string
|
||||||
@ -65,6 +66,7 @@ const METHOD_ICONS: Record<string, string> = {
|
|||||||
alipay: alipayIcon,
|
alipay: alipayIcon,
|
||||||
wxpay: wxpayIcon,
|
wxpay: wxpayIcon,
|
||||||
stripe: stripeIcon,
|
stripe: stripeIcon,
|
||||||
|
airwallex: airwallexIcon,
|
||||||
}
|
}
|
||||||
|
|
||||||
const sortedMethods = computed(() => {
|
const sortedMethods = computed(() => {
|
||||||
@ -79,6 +81,7 @@ const sortedMethods = computed(() => {
|
|||||||
function methodIcon(type: string): string {
|
function methodIcon(type: string): string {
|
||||||
if (type.includes('alipay')) return METHOD_ICONS.alipay
|
if (type.includes('alipay')) return METHOD_ICONS.alipay
|
||||||
if (type.includes('wxpay')) return METHOD_ICONS.wxpay
|
if (type.includes('wxpay')) return METHOD_ICONS.wxpay
|
||||||
|
if (type === 'airwallex') return METHOD_ICONS.airwallex
|
||||||
return METHOD_ICONS[type] || alipayIcon
|
return METHOD_ICONS[type] || alipayIcon
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,6 +89,7 @@ function methodSelectedClass(type: string): string {
|
|||||||
if (type.includes('alipay')) return 'border-[#02A9F1] bg-blue-50 text-gray-900 shadow-sm dark:bg-blue-950 dark:text-gray-100'
|
if (type.includes('alipay')) return 'border-[#02A9F1] bg-blue-50 text-gray-900 shadow-sm dark:bg-blue-950 dark:text-gray-100'
|
||||||
if (type.includes('wxpay')) return 'border-[#09BB07] bg-green-50 text-gray-900 shadow-sm dark:bg-green-950 dark:text-gray-100'
|
if (type.includes('wxpay')) return 'border-[#09BB07] bg-green-50 text-gray-900 shadow-sm dark:bg-green-950 dark:text-gray-100'
|
||||||
if (type === 'stripe') return 'border-[#676BE5] bg-indigo-50 text-gray-900 shadow-sm dark:bg-indigo-950 dark:text-gray-100'
|
if (type === 'stripe') return 'border-[#676BE5] bg-indigo-50 text-gray-900 shadow-sm dark:bg-indigo-950 dark:text-gray-100'
|
||||||
|
if (type === 'airwallex') return 'border-[#FF6B3D] bg-orange-50 text-gray-900 shadow-sm dark:border-[#FF8E3C] dark:bg-orange-950 dark:text-gray-100'
|
||||||
return 'border-primary-500 bg-primary-50 text-gray-900 shadow-sm dark:bg-primary-950 dark:text-gray-100'
|
return 'border-primary-500 bg-primary-50 text-gray-900 shadow-sm dark:bg-primary-950 dark:text-gray-100'
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@ -149,6 +149,12 @@
|
|||||||
<svg v-else class="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M2.458 12C3.732 7.943 7.523 5 12 5c4.478 0 8.268 2.943 9.542 7-1.274 4.057-5.064 7-9.542 7-4.477 0-8.268-2.943-9.542-7z" /></svg>
|
<svg v-else class="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M2.458 12C3.732 7.943 7.523 5 12 5c4.478 0 8.268 2.943 9.542 7-1.274 4.057-5.064 7-9.542 7-4.477 0-8.268-2.943-9.542-7z" /></svg>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
<Select
|
||||||
|
v-else-if="field.options?.length"
|
||||||
|
v-model="config[field.key]"
|
||||||
|
:options="field.options"
|
||||||
|
:searchable="field.options.length > 5"
|
||||||
|
/>
|
||||||
<input
|
<input
|
||||||
v-else
|
v-else
|
||||||
type="text"
|
type="text"
|
||||||
@ -156,6 +162,9 @@
|
|||||||
class="input"
|
class="input"
|
||||||
:placeholder="field.defaultValue || ''"
|
:placeholder="field.defaultValue || ''"
|
||||||
/>
|
/>
|
||||||
|
<p v-if="field.hintKey" class="mt-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t(field.hintKey) }}
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -177,14 +186,17 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Stripe webhook hint -->
|
<!-- 服务商 Webhook 提示 -->
|
||||||
<div v-if="stripeWebhookUrl" class="mt-3 rounded-lg border border-blue-200 bg-blue-50 p-3 dark:border-blue-800/50 dark:bg-blue-900/20">
|
<div v-if="providerWebhookUrl" class="mt-3 rounded-lg border border-blue-200 bg-blue-50 p-3 dark:border-blue-800/50 dark:bg-blue-900/20">
|
||||||
<p class="text-xs text-blue-700 dark:text-blue-300">
|
<p class="text-xs text-blue-700 dark:text-blue-300">
|
||||||
{{ t('admin.settings.payment.stripeWebhookHint') }}
|
{{ t(providerWebhookHint) }}
|
||||||
</p>
|
</p>
|
||||||
<code class="mt-1 block break-all rounded bg-blue-100 px-2 py-1 text-xs text-blue-800 dark:bg-blue-900/40 dark:text-blue-200">
|
<code class="mt-1 block break-all rounded bg-blue-100 px-2 py-1 text-xs text-blue-800 dark:bg-blue-900/40 dark:text-blue-200">
|
||||||
{{ stripeWebhookUrl }}
|
{{ providerWebhookUrl }}
|
||||||
</code>
|
</code>
|
||||||
|
<p v-if="form.provider_key === 'stripe'" class="mt-2 text-xs leading-relaxed text-blue-700 dark:text-blue-300">
|
||||||
|
{{ t('admin.settings.payment.stripeWebhookApiVersionHint', { version: STRIPE_SDK_API_VERSION }) }}
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -266,6 +278,7 @@ import {
|
|||||||
WEBHOOK_PATHS,
|
WEBHOOK_PATHS,
|
||||||
PAYMENT_MODE_QRCODE,
|
PAYMENT_MODE_QRCODE,
|
||||||
PAYMENT_MODE_POPUP,
|
PAYMENT_MODE_POPUP,
|
||||||
|
STRIPE_SDK_API_VERSION,
|
||||||
getAvailableTypes,
|
getAvailableTypes,
|
||||||
extractBaseUrl,
|
extractBaseUrl,
|
||||||
} from './providerConfig'
|
} from './providerConfig'
|
||||||
@ -330,8 +343,18 @@ const visibleFields = reactive<Record<string, boolean>>({})
|
|||||||
// --- Computed ---
|
// --- Computed ---
|
||||||
const defaultBaseUrl = typeof window !== 'undefined' ? window.location.origin : ''
|
const defaultBaseUrl = typeof window !== 'undefined' ? window.location.origin : ''
|
||||||
|
|
||||||
const stripeWebhookUrl = computed(() =>
|
const providerWebhookHintMap: Record<string, string> = {
|
||||||
form.provider_key === 'stripe' ? defaultBaseUrl + WEBHOOK_PATHS.stripe : '',
|
stripe: 'admin.settings.payment.stripeWebhookHint',
|
||||||
|
airwallex: 'admin.settings.payment.airwallexWebhookHint',
|
||||||
|
}
|
||||||
|
|
||||||
|
const providerWebhookUrl = computed(() => {
|
||||||
|
const path = WEBHOOK_PATHS[form.provider_key]
|
||||||
|
return providerWebhookHintMap[form.provider_key] && path ? defaultBaseUrl + path : ''
|
||||||
|
})
|
||||||
|
|
||||||
|
const providerWebhookHint = computed(() =>
|
||||||
|
providerWebhookHintMap[form.provider_key] || 'admin.settings.payment.stripeWebhookHint',
|
||||||
)
|
)
|
||||||
|
|
||||||
const callbackPaths = computed(() => PROVIDER_CALLBACK_PATHS[form.provider_key] || null)
|
const callbackPaths = computed(() => PROVIDER_CALLBACK_PATHS[form.provider_key] || null)
|
||||||
@ -415,6 +438,14 @@ const paymentGuide = computed<PaymentGuide | null>(() => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (form.provider_key === 'airwallex') {
|
||||||
|
return {
|
||||||
|
summary: t('admin.settings.payment.airwallexGuideSummary'),
|
||||||
|
note: t('admin.settings.payment.airwallexGuideNote'),
|
||||||
|
items: [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return null
|
return null
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -527,9 +558,19 @@ function handleSave() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const clearableConfigKeys = new Set(
|
||||||
|
(PROVIDER_CONFIG_FIELDS[form.provider_key] || [])
|
||||||
|
.filter(field => field.clearable)
|
||||||
|
.map(field => field.key),
|
||||||
|
)
|
||||||
const filteredConfig: Record<string, string> = {}
|
const filteredConfig: Record<string, string> = {}
|
||||||
for (const [k, v] of Object.entries(config)) {
|
for (const [k, v] of Object.entries(config)) {
|
||||||
if (!v || !v.trim()) continue
|
if (!v || !v.trim()) {
|
||||||
|
if (clearableConfigKeys.has(k)) {
|
||||||
|
filteredConfig[k] = ''
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
filteredConfig[k] = v
|
filteredConfig[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -22,11 +22,11 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="flex justify-between">
|
<div class="flex justify-between">
|
||||||
<span class="text-gray-500 dark:text-gray-400">{{ t('payment.orders.amount') }}</span>
|
<span class="text-gray-500 dark:text-gray-400">{{ t('payment.orders.amount') }}</span>
|
||||||
<span class="font-medium text-gray-900 dark:text-white">{{ paidOrder.order_type === 'balance' ? '$' : '¥' }}{{ paidOrder.amount.toFixed(2) }}</span>
|
<span class="font-medium text-gray-900 dark:text-white">{{ paidOrder.order_type === 'balance' ? '$' + paidOrder.amount.toFixed(2) : formatGatewayAmount(paidOrder.amount) }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex justify-between">
|
<div class="flex justify-between">
|
||||||
<span class="text-gray-500 dark:text-gray-400">{{ t('payment.orders.payAmount') }}</span>
|
<span class="text-gray-500 dark:text-gray-400">{{ t('payment.orders.payAmount') }}</span>
|
||||||
<span class="font-medium text-gray-900 dark:text-white">¥{{ paidOrder.pay_amount.toFixed(2) }}</span>
|
<span class="font-medium text-gray-900 dark:text-white">{{ formatGatewayAmount(paidOrder.pay_amount) }}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@ -129,6 +129,7 @@ import { useAppStore } from '@/stores'
|
|||||||
import { paymentAPI } from '@/api/payment'
|
import { paymentAPI } from '@/api/payment'
|
||||||
import { extractI18nErrorMessage } from '@/utils/apiError'
|
import { extractI18nErrorMessage } from '@/utils/apiError'
|
||||||
import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
|
import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
|
||||||
|
import { formatPaymentAmount, normalizePaymentCurrency } from '@/components/payment/currency'
|
||||||
import type { PaymentOrder } from '@/types/payment'
|
import type { PaymentOrder } from '@/types/payment'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import QRCode from 'qrcode'
|
import QRCode from 'qrcode'
|
||||||
@ -142,13 +143,15 @@ const props = defineProps<{
|
|||||||
paymentType: string
|
paymentType: string
|
||||||
payUrl?: string
|
payUrl?: string
|
||||||
orderType?: string
|
orderType?: string
|
||||||
|
currency?: string
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
type PaymentOutcome = 'success' | 'cancelled' | 'expired'
|
type PaymentOutcome = 'success' | 'cancelled' | 'expired'
|
||||||
|
|
||||||
const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>()
|
const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>()
|
||||||
|
|
||||||
const { t } = useI18n()
|
const i18n = useI18n()
|
||||||
|
const { t } = i18n
|
||||||
const paymentStore = usePaymentStore()
|
const paymentStore = usePaymentStore()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
|
||||||
@ -157,6 +160,15 @@ const qrUrl = ref('')
|
|||||||
const remainingSeconds = ref(0)
|
const remainingSeconds = ref(0)
|
||||||
const cancelling = ref(false)
|
const cancelling = ref(false)
|
||||||
const paidOrder = ref<PaymentOrder | null>(null)
|
const paidOrder = ref<PaymentOrder | null>(null)
|
||||||
|
const paymentCurrency = computed(() => normalizePaymentCurrency(props.currency))
|
||||||
|
const localeCode = computed(() => {
|
||||||
|
const raw = i18n.locale as unknown
|
||||||
|
if (typeof raw === 'string') return raw
|
||||||
|
if (raw && typeof raw === 'object' && 'value' in raw) {
|
||||||
|
return String((raw as { value?: string }).value || '')
|
||||||
|
}
|
||||||
|
return undefined
|
||||||
|
})
|
||||||
|
|
||||||
// Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired'
|
// Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired'
|
||||||
const outcome = ref<PaymentOutcome | null>(null)
|
const outcome = ref<PaymentOutcome | null>(null)
|
||||||
@ -197,6 +209,10 @@ const countdownDisplay = computed(() => {
|
|||||||
return m.toString().padStart(2, '0') + ':' + s.toString().padStart(2, '0')
|
return m.toString().padStart(2, '0') + ':' + s.toString().padStart(2, '0')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
function formatGatewayAmount(value: number): string {
|
||||||
|
return formatPaymentAmount(value, paymentCurrency.value, localeCode.value)
|
||||||
|
}
|
||||||
|
|
||||||
function isSuccessStatus(status: string | null | undefined): boolean {
|
function isSuccessStatus(status: string | null | undefined): boolean {
|
||||||
return status === 'COMPLETED' || status === 'PAID' || status === 'RECHARGING'
|
return status === 'COMPLETED' || status === 'PAID' || status === 'RECHARGING'
|
||||||
}
|
}
|
||||||
|
|||||||
@ -76,6 +76,7 @@ const PROVIDER_KEY_LABELS: Record<string, string> = {
|
|||||||
alipay: 'admin.settings.payment.providerAlipay',
|
alipay: 'admin.settings.payment.providerAlipay',
|
||||||
wxpay: 'admin.settings.payment.providerWxpay',
|
wxpay: 'admin.settings.payment.providerWxpay',
|
||||||
stripe: 'admin.settings.payment.providerStripe',
|
stripe: 'admin.settings.payment.providerStripe',
|
||||||
|
airwallex: 'admin.settings.payment.providerAirwallex',
|
||||||
}
|
}
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
|
|||||||
@ -2,34 +2,66 @@ import { describe, expect, it, vi } from 'vitest'
|
|||||||
import { mount } from '@vue/test-utils'
|
import { mount } from '@vue/test-utils'
|
||||||
import { nextTick } from 'vue'
|
import { nextTick } from 'vue'
|
||||||
import PaymentProviderDialog from '@/components/payment/PaymentProviderDialog.vue'
|
import PaymentProviderDialog from '@/components/payment/PaymentProviderDialog.vue'
|
||||||
|
import { STRIPE_SDK_API_VERSION } from '@/components/payment/providerConfig'
|
||||||
|
import type { ProviderInstance } from '@/types/payment'
|
||||||
|
|
||||||
const messages: Record<string, string> = {
|
const messages: Record<string, string> = {
|
||||||
'admin.settings.payment.providerConfig': 'Credentials',
|
'admin.settings.payment.providerConfig': 'Credentials',
|
||||||
'admin.settings.payment.paymentGuideTrigger': 'View payment guide',
|
'admin.settings.payment.paymentGuideTrigger': 'View payment guide',
|
||||||
'admin.settings.payment.alipayGuideSummary': 'Desktop prefers QR precreate and falls back to cashier; mobile prefers WAP checkout.',
|
'admin.settings.payment.alipayGuideSummary': 'Desktop prefers QR precreate and falls back to cashier; mobile prefers WAP checkout.',
|
||||||
'admin.settings.payment.wxpayGuideSummary': 'Desktop prefers Native QR; mobile routes to JSAPI or H5 based on browser context.',
|
'admin.settings.payment.wxpayGuideSummary': 'Desktop prefers Native QR; mobile routes to JSAPI or H5 based on browser context.',
|
||||||
|
'admin.settings.payment.airwallexGuideSummary': 'Use Payment Acceptance read/write only.',
|
||||||
|
'admin.settings.payment.stripeWebhookHint': 'Configure Stripe webhook.',
|
||||||
|
'admin.settings.payment.stripeWebhookApiVersionHint': 'Use Stripe API version {version}.',
|
||||||
|
'admin.settings.payment.airwallexWebhookHint': 'Select payment_intent.succeeded and use the latest stable API version.',
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.mock('vue-i18n', () => ({
|
vi.mock('vue-i18n', () => ({
|
||||||
useI18n: () => ({
|
useI18n: () => ({
|
||||||
t: (key: string) => messages[key] ?? key,
|
t: (key: string, params?: Record<string, string>) => {
|
||||||
|
const message = messages[key] ?? key
|
||||||
|
if (!params) return message
|
||||||
|
return Object.entries(params).reduce(
|
||||||
|
(value, [name, replacement]) => value.replaceAll(`{${name}}`, replacement),
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
function mountDialog() {
|
function providerFactory(overrides: Partial<ProviderInstance> = {}): ProviderInstance {
|
||||||
|
return {
|
||||||
|
id: 1,
|
||||||
|
provider_key: 'airwallex',
|
||||||
|
name: 'Airwallex',
|
||||||
|
config: {},
|
||||||
|
supported_types: ['airwallex'],
|
||||||
|
enabled: true,
|
||||||
|
payment_mode: '',
|
||||||
|
refund_enabled: false,
|
||||||
|
allow_user_refund: false,
|
||||||
|
limits: '',
|
||||||
|
sort_order: 0,
|
||||||
|
...overrides,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function mountDialog(options: { editing?: ProviderInstance | null } = {}) {
|
||||||
return mount(PaymentProviderDialog, {
|
return mount(PaymentProviderDialog, {
|
||||||
props: {
|
props: {
|
||||||
show: true,
|
show: true,
|
||||||
saving: false,
|
saving: false,
|
||||||
editing: null,
|
editing: options.editing ?? null,
|
||||||
allKeyOptions: [
|
allKeyOptions: [
|
||||||
{ value: 'alipay', label: 'Alipay' },
|
{ value: 'alipay', label: 'Alipay' },
|
||||||
{ value: 'wxpay', label: 'WeChat Pay' },
|
{ value: 'wxpay', label: 'WeChat Pay' },
|
||||||
{ value: 'stripe', label: 'Stripe' },
|
{ value: 'stripe', label: 'Stripe' },
|
||||||
|
{ value: 'airwallex', label: 'Airwallex' },
|
||||||
],
|
],
|
||||||
enabledKeyOptions: [
|
enabledKeyOptions: [
|
||||||
{ value: 'alipay', label: 'Alipay' },
|
{ value: 'alipay', label: 'Alipay' },
|
||||||
{ value: 'wxpay', label: 'WeChat Pay' },
|
{ value: 'wxpay', label: 'WeChat Pay' },
|
||||||
|
{ value: 'airwallex', label: 'Airwallex' },
|
||||||
],
|
],
|
||||||
allPaymentTypes: [
|
allPaymentTypes: [
|
||||||
{ value: 'alipay', label: 'Alipay' },
|
{ value: 'alipay', label: 'Alipay' },
|
||||||
@ -66,6 +98,7 @@ describe('PaymentProviderDialog payment guide', () => {
|
|||||||
it.each([
|
it.each([
|
||||||
['alipay', 'admin.settings.payment.alipayGuideSummary'],
|
['alipay', 'admin.settings.payment.alipayGuideSummary'],
|
||||||
['wxpay', 'admin.settings.payment.wxpayGuideSummary'],
|
['wxpay', 'admin.settings.payment.wxpayGuideSummary'],
|
||||||
|
['airwallex', 'admin.settings.payment.airwallexGuideSummary'],
|
||||||
])('shows the payment guide summary for %s', async (providerKey, summaryKey) => {
|
])('shows the payment guide summary for %s', async (providerKey, summaryKey) => {
|
||||||
const wrapper = mountDialog()
|
const wrapper = mountDialog()
|
||||||
|
|
||||||
@ -75,4 +108,52 @@ describe('PaymentProviderDialog payment guide', () => {
|
|||||||
expect(wrapper.text()).toContain(messages[summaryKey])
|
expect(wrapper.text()).toContain(messages[summaryKey])
|
||||||
expect(wrapper.find('button[title="View payment guide"]').exists()).toBe(true)
|
expect(wrapper.find('button[title="View payment guide"]').exists()).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('shows Airwallex webhook event and API version guidance with the webhook URL', async () => {
|
||||||
|
const wrapper = mountDialog()
|
||||||
|
|
||||||
|
;(wrapper.vm as unknown as { reset: (key: string) => void }).reset('airwallex')
|
||||||
|
await nextTick()
|
||||||
|
|
||||||
|
expect(wrapper.text()).toContain(messages['admin.settings.payment.airwallexWebhookHint'])
|
||||||
|
expect(wrapper.text()).toContain('/api/v1/payment/webhook/airwallex')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('shows Stripe webhook API version guidance with the integrated SDK version', async () => {
|
||||||
|
const wrapper = mountDialog()
|
||||||
|
|
||||||
|
;(wrapper.vm as unknown as { reset: (key: string) => void }).reset('stripe')
|
||||||
|
await nextTick()
|
||||||
|
|
||||||
|
expect(wrapper.text()).toContain(messages['admin.settings.payment.stripeWebhookHint'])
|
||||||
|
expect(wrapper.text()).toContain(`Use Stripe API version ${STRIPE_SDK_API_VERSION}.`)
|
||||||
|
expect(wrapper.text()).toContain('/api/v1/payment/webhook/stripe')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('emits an empty Airwallex accountId when the admin clears it', async () => {
|
||||||
|
const provider = providerFactory({
|
||||||
|
config: {
|
||||||
|
clientId: 'cid_123',
|
||||||
|
apiBase: 'https://api.airwallex.com/api/v1',
|
||||||
|
countryCode: 'CN',
|
||||||
|
currency: 'CNY',
|
||||||
|
accountId: 'acct_123',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
const wrapper = mountDialog({ editing: provider })
|
||||||
|
|
||||||
|
;(wrapper.vm as unknown as { loadProvider: (provider: ProviderInstance) => void }).loadProvider(provider)
|
||||||
|
await nextTick()
|
||||||
|
|
||||||
|
const accountIdInput = wrapper
|
||||||
|
.findAll('input[type="text"]')
|
||||||
|
.find(input => (input.element as HTMLInputElement).value === 'acct_123')
|
||||||
|
if (!accountIdInput) throw new Error('accountId input not found')
|
||||||
|
|
||||||
|
await accountIdInput.setValue('')
|
||||||
|
await wrapper.find('form').trigger('submit.prevent')
|
||||||
|
|
||||||
|
const payload = wrapper.emitted('save')?.[0]?.[0] as { config: Record<string, string> }
|
||||||
|
expect(payload.config.accountId).toBe('')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
10
frontend/src/components/payment/__tests__/currency.spec.ts
Normal file
10
frontend/src/components/payment/__tests__/currency.spec.ts
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
import { formatPaymentAmount } from '../currency'
|
||||||
|
|
||||||
|
describe('formatPaymentAmount', () => {
|
||||||
|
it('uses the currency default fraction digits', () => {
|
||||||
|
expect(formatPaymentAmount(100, 'JPY', 'en-US')).not.toContain('.00')
|
||||||
|
expect(formatPaymentAmount(100, 'KRW', 'en-US')).not.toContain('.00')
|
||||||
|
expect(formatPaymentAmount(100, 'HKD', 'en-US')).toContain('.00')
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -38,12 +38,14 @@ describe('getVisibleMethods', () => {
|
|||||||
alipay_direct: methodLimit({ single_min: 5 }),
|
alipay_direct: methodLimit({ single_min: 5 }),
|
||||||
wxpay: methodLimit({ single_max: 100 }),
|
wxpay: methodLimit({ single_max: 100 }),
|
||||||
stripe: methodLimit({ fee_rate: 3 }),
|
stripe: methodLimit({ fee_rate: 3 }),
|
||||||
|
airwallex: methodLimit({ single_min: 10 }),
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(visible).toEqual({
|
expect(visible).toEqual({
|
||||||
alipay: methodLimit({ single_min: 5 }),
|
alipay: methodLimit({ single_min: 5 }),
|
||||||
wxpay: methodLimit({ single_max: 100 }),
|
wxpay: methodLimit({ single_max: 100 }),
|
||||||
stripe: methodLimit({ fee_rate: 3 }),
|
stripe: methodLimit({ fee_rate: 3 }),
|
||||||
|
airwallex: methodLimit({ single_min: 10 }),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -104,6 +106,29 @@ describe('decidePaymentLaunch', () => {
|
|||||||
expect(decision.paymentState.orderType).toBe('subscription')
|
expect(decision.paymentState.orderType).toBe('subscription')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('routes Airwallex client secrets through the hosted Airwallex page', () => {
|
||||||
|
const decision = decidePaymentLaunch(createOrderResult({
|
||||||
|
client_secret: 'awx_cs',
|
||||||
|
intent_id: 'int_awx',
|
||||||
|
currency: 'CNY',
|
||||||
|
country_code: 'CN',
|
||||||
|
payment_env: 'demo',
|
||||||
|
out_trade_no: 'sub2_awx',
|
||||||
|
}), {
|
||||||
|
visibleMethod: 'airwallex',
|
||||||
|
orderType: 'balance',
|
||||||
|
isMobile: false,
|
||||||
|
airwallexRouteUrl: '/payment/airwallex?order_id=101',
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(decision.kind).toBe('airwallex_route')
|
||||||
|
expect(decision.paymentState.payUrl).toBe('/payment/airwallex?order_id=101')
|
||||||
|
expect(decision.paymentState.intentId).toBe('int_awx')
|
||||||
|
expect(decision.paymentState.currency).toBe('CNY')
|
||||||
|
expect(decision.paymentState.countryCode).toBe('CN')
|
||||||
|
expect(decision.paymentState.paymentEnv).toBe('demo')
|
||||||
|
})
|
||||||
|
|
||||||
it('keeps hosted redirect metadata for recovery flows', () => {
|
it('keeps hosted redirect metadata for recovery flows', () => {
|
||||||
const decision = decidePaymentLaunch(createOrderResult({
|
const decision = decidePaymentLaunch(createOrderResult({
|
||||||
pay_url: 'https://pay.example.com/session/abc',
|
pay_url: 'https://pay.example.com/session/abc',
|
||||||
@ -248,6 +273,10 @@ describe('readPaymentRecoverySnapshot', () => {
|
|||||||
payUrl: 'https://pay.example.com/session/33',
|
payUrl: 'https://pay.example.com/session/33',
|
||||||
outTradeNo: 'sub2_33',
|
outTradeNo: 'sub2_33',
|
||||||
clientSecret: '',
|
clientSecret: '',
|
||||||
|
intentId: '',
|
||||||
|
currency: '',
|
||||||
|
countryCode: '',
|
||||||
|
paymentEnv: '',
|
||||||
payAmount: 18,
|
payAmount: 18,
|
||||||
orderType: 'balance',
|
orderType: 'balance',
|
||||||
paymentMode: 'popup',
|
paymentMode: 'popup',
|
||||||
@ -273,6 +302,10 @@ describe('readPaymentRecoverySnapshot', () => {
|
|||||||
payUrl: 'https://pay.example.com/session/55',
|
payUrl: 'https://pay.example.com/session/55',
|
||||||
outTradeNo: 'sub2_55',
|
outTradeNo: 'sub2_55',
|
||||||
clientSecret: '',
|
clientSecret: '',
|
||||||
|
intentId: '',
|
||||||
|
currency: '',
|
||||||
|
countryCode: '',
|
||||||
|
paymentEnv: '',
|
||||||
payAmount: 18,
|
payAmount: 18,
|
||||||
orderType: 'balance',
|
orderType: 'balance',
|
||||||
paymentMode: 'popup',
|
paymentMode: 'popup',
|
||||||
@ -317,4 +350,31 @@ describe('readPaymentRecoverySnapshot', () => {
|
|||||||
expect(restored?.orderId).toBe(44)
|
expect(restored?.orderId).toBe(44)
|
||||||
expect(restored?.outTradeNo).toBe('')
|
expect(restored?.outTradeNo).toBe('')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('keeps backward compatibility with snapshots written before Airwallex fields existed', () => {
|
||||||
|
const restored = readPaymentRecoverySnapshot(JSON.stringify({
|
||||||
|
orderId: 45,
|
||||||
|
amount: 28,
|
||||||
|
qrCode: '',
|
||||||
|
expiresAt: '2099-01-01T00:10:00.000Z',
|
||||||
|
paymentType: 'airwallex',
|
||||||
|
payUrl: '/payment/airwallex?order_id=45',
|
||||||
|
outTradeNo: 'sub2_45',
|
||||||
|
clientSecret: 'awx_cs',
|
||||||
|
payAmount: 28,
|
||||||
|
orderType: 'balance',
|
||||||
|
paymentMode: '',
|
||||||
|
resumeToken: 'resume-45',
|
||||||
|
createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
|
||||||
|
}), {
|
||||||
|
now: Date.UTC(2099, 0, 1, 0, 1, 0),
|
||||||
|
resumeToken: 'resume-45',
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(restored?.orderId).toBe(45)
|
||||||
|
expect(restored?.intentId).toBe('')
|
||||||
|
expect(restored?.currency).toBe('')
|
||||||
|
expect(restored?.countryCode).toBe('')
|
||||||
|
expect(restored?.paymentEnv).toBe('')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,20 +1,52 @@
|
|||||||
import { describe, expect, it } from 'vitest'
|
import { describe, expect, it } from 'vitest'
|
||||||
import { PROVIDER_CONFIG_FIELDS } from '@/components/payment/providerConfig'
|
import { PAYMENT_CURRENCY_OPTIONS, PROVIDER_CONFIG_FIELDS } from '@/components/payment/providerConfig'
|
||||||
|
|
||||||
function findField(key: string) {
|
function findField(providerKey: string, key: string) {
|
||||||
const fields = PROVIDER_CONFIG_FIELDS.wxpay || []
|
const fields = PROVIDER_CONFIG_FIELDS[providerKey] || []
|
||||||
return fields.find(field => field.key === key)
|
return fields.find(field => field.key === key)
|
||||||
}
|
}
|
||||||
|
|
||||||
describe('PROVIDER_CONFIG_FIELDS.wxpay', () => {
|
describe('PROVIDER_CONFIG_FIELDS.wxpay', () => {
|
||||||
it('keeps admin form validation aligned with backend-required credentials', () => {
|
it('keeps admin form validation aligned with backend-required credentials', () => {
|
||||||
expect(findField('publicKeyId')?.optional).toBeFalsy()
|
expect(findField('wxpay', 'publicKeyId')?.optional).toBeFalsy()
|
||||||
expect(findField('certSerial')?.optional).toBeFalsy()
|
expect(findField('wxpay', 'certSerial')?.optional).toBeFalsy()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('only keeps the simplified visible credential set in the admin form', () => {
|
it('only keeps the simplified visible credential set in the admin form', () => {
|
||||||
expect(findField('mpAppId')).toBeUndefined()
|
expect(findField('wxpay', 'mpAppId')).toBeUndefined()
|
||||||
expect(findField('h5AppName')).toBeUndefined()
|
expect(findField('wxpay', 'h5AppName')).toBeUndefined()
|
||||||
expect(findField('h5AppUrl')).toBeUndefined()
|
expect(findField('wxpay', 'h5AppUrl')).toBeUndefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('PROVIDER_CONFIG_FIELDS.airwallex', () => {
|
||||||
|
it('adds currency config with CNY as the default', () => {
|
||||||
|
const currency = findField('airwallex', 'currency')
|
||||||
|
|
||||||
|
expect(currency?.defaultValue).toBe('CNY')
|
||||||
|
expect(currency?.hintKey).toBe('admin.settings.payment.field_paymentCurrencyHint')
|
||||||
|
expect(currency?.options).toBe(PAYMENT_CURRENCY_OPTIONS)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('marks accountId as optional and explains when it can be left blank', () => {
|
||||||
|
const accountId = findField('airwallex', 'accountId')
|
||||||
|
|
||||||
|
expect(accountId?.optional).toBe(true)
|
||||||
|
expect(accountId?.clearable).toBe(true)
|
||||||
|
expect(accountId?.hintKey).toBe('admin.settings.payment.field_accountIdHint')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('explains that apiBase must match the Airwallex key environment', () => {
|
||||||
|
expect(findField('airwallex', 'apiBase')?.hintKey).toBe('admin.settings.payment.field_airwallexApiBaseHint')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('PROVIDER_CONFIG_FIELDS.stripe', () => {
|
||||||
|
it('adds currency config with CNY as the default', () => {
|
||||||
|
const currency = findField('stripe', 'currency')
|
||||||
|
|
||||||
|
expect(currency?.defaultValue).toBe('CNY')
|
||||||
|
expect(currency?.hintKey).toBe('admin.settings.payment.field_paymentCurrencyHint')
|
||||||
|
expect(currency?.options).toBe(PAYMENT_CURRENCY_OPTIONS)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
33
frontend/src/components/payment/currency.ts
Normal file
33
frontend/src/components/payment/currency.ts
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
export const DEFAULT_PAYMENT_CURRENCY = 'CNY'
|
||||||
|
|
||||||
|
export function normalizePaymentCurrency(currency?: string | null): string {
|
||||||
|
const normalized = String(currency || '').trim().toUpperCase()
|
||||||
|
return /^[A-Z]{3}$/.test(normalized) ? normalized : DEFAULT_PAYMENT_CURRENCY
|
||||||
|
}
|
||||||
|
|
||||||
|
function paymentCurrencyFractionDigits(currency: string): number {
|
||||||
|
try {
|
||||||
|
return new Intl.NumberFormat(undefined, {
|
||||||
|
style: 'currency',
|
||||||
|
currency,
|
||||||
|
}).resolvedOptions().maximumFractionDigits ?? 2
|
||||||
|
} catch {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function formatPaymentAmount(amount: number, currency?: string | null, locale?: string): string {
|
||||||
|
const normalized = normalizePaymentCurrency(currency)
|
||||||
|
const fractionDigits = paymentCurrencyFractionDigits(normalized)
|
||||||
|
try {
|
||||||
|
return new Intl.NumberFormat(locale || undefined, {
|
||||||
|
style: 'currency',
|
||||||
|
currency: normalized,
|
||||||
|
currencyDisplay: 'narrowSymbol',
|
||||||
|
minimumFractionDigits: fractionDigits,
|
||||||
|
maximumFractionDigits: fractionDigits,
|
||||||
|
}).format(Number.isFinite(amount) ? amount : 0)
|
||||||
|
} catch {
|
||||||
|
return `${normalized} ${(Number.isFinite(amount) ? amount : 0).toFixed(fractionDigits)}`
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -15,15 +15,17 @@ const VISIBLE_METHOD_ALIASES = {
|
|||||||
wxpay: 'wxpay',
|
wxpay: 'wxpay',
|
||||||
wxpay_direct: 'wxpay',
|
wxpay_direct: 'wxpay',
|
||||||
stripe: 'stripe',
|
stripe: 'stripe',
|
||||||
|
airwallex: 'airwallex',
|
||||||
} as const
|
} as const
|
||||||
|
|
||||||
export type VisiblePaymentMethod = 'alipay' | 'wxpay' | 'stripe'
|
export type VisiblePaymentMethod = 'alipay' | 'wxpay' | 'stripe' | 'airwallex'
|
||||||
export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
|
export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
|
||||||
export type PaymentLaunchKind =
|
export type PaymentLaunchKind =
|
||||||
| 'qr_waiting'
|
| 'qr_waiting'
|
||||||
| 'redirect_waiting'
|
| 'redirect_waiting'
|
||||||
| 'stripe_popup'
|
| 'stripe_popup'
|
||||||
| 'stripe_route'
|
| 'stripe_route'
|
||||||
|
| 'airwallex_route'
|
||||||
| 'wechat_oauth'
|
| 'wechat_oauth'
|
||||||
| 'wechat_jsapi'
|
| 'wechat_jsapi'
|
||||||
| 'unhandled'
|
| 'unhandled'
|
||||||
@ -37,6 +39,10 @@ export interface PaymentRecoverySnapshot {
|
|||||||
payUrl: string
|
payUrl: string
|
||||||
outTradeNo: string
|
outTradeNo: string
|
||||||
clientSecret: string
|
clientSecret: string
|
||||||
|
intentId: string
|
||||||
|
currency: string
|
||||||
|
countryCode: string
|
||||||
|
paymentEnv: string
|
||||||
payAmount: number
|
payAmount: number
|
||||||
orderType: OrderType | ''
|
orderType: OrderType | ''
|
||||||
paymentMode: string
|
paymentMode: string
|
||||||
@ -52,6 +58,7 @@ export interface PaymentLaunchContext {
|
|||||||
now?: number
|
now?: number
|
||||||
stripePopupUrl?: string
|
stripePopupUrl?: string
|
||||||
stripeRouteUrl?: string
|
stripeRouteUrl?: string
|
||||||
|
airwallexRouteUrl?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface PaymentLaunchDecision {
|
export interface PaymentLaunchDecision {
|
||||||
@ -138,12 +145,24 @@ export function decidePaymentLaunch(
|
|||||||
payUrl: result.pay_url || '',
|
payUrl: result.pay_url || '',
|
||||||
outTradeNo: result.out_trade_no || '',
|
outTradeNo: result.out_trade_no || '',
|
||||||
clientSecret: result.client_secret || '',
|
clientSecret: result.client_secret || '',
|
||||||
|
intentId: result.intent_id || '',
|
||||||
|
currency: result.currency || '',
|
||||||
|
countryCode: result.country_code || '',
|
||||||
|
paymentEnv: result.payment_env || '',
|
||||||
payAmount: result.pay_amount,
|
payAmount: result.pay_amount,
|
||||||
orderType: context.orderType,
|
orderType: context.orderType,
|
||||||
paymentMode: (result.payment_mode || '').trim(),
|
paymentMode: (result.payment_mode || '').trim(),
|
||||||
resumeToken: result.resume_token || '',
|
resumeToken: result.resume_token || '',
|
||||||
}, context.now)
|
}, context.now)
|
||||||
|
|
||||||
|
if (visibleMethod === 'airwallex' && baseState.clientSecret && baseState.intentId) {
|
||||||
|
if (!context.airwallexRouteUrl) {
|
||||||
|
return { kind: 'unhandled', paymentState: baseState, recovery: baseState }
|
||||||
|
}
|
||||||
|
const paymentState = { ...baseState, payUrl: context.airwallexRouteUrl || '' }
|
||||||
|
return { kind: 'airwallex_route', paymentState, recovery: paymentState }
|
||||||
|
}
|
||||||
|
|
||||||
if (baseState.clientSecret) {
|
if (baseState.clientSecret) {
|
||||||
// visibleMethod === 'stripe' means the user clicked the dedicated Stripe button
|
// visibleMethod === 'stripe' means the user clicked the dedicated Stripe button
|
||||||
// and should land on the full Payment Element to choose a sub-method themselves.
|
// and should land on the full Payment Element to choose a sub-method themselves.
|
||||||
@ -239,6 +258,10 @@ export function readPaymentRecoverySnapshot(
|
|||||||
|| typeof parsed.payUrl !== 'string'
|
|| typeof parsed.payUrl !== 'string'
|
||||||
|| (parsed.outTradeNo != null && typeof parsed.outTradeNo !== 'string')
|
|| (parsed.outTradeNo != null && typeof parsed.outTradeNo !== 'string')
|
||||||
|| typeof parsed.clientSecret !== 'string'
|
|| typeof parsed.clientSecret !== 'string'
|
||||||
|
|| (parsed.intentId != null && typeof parsed.intentId !== 'string')
|
||||||
|
|| (parsed.currency != null && typeof parsed.currency !== 'string')
|
||||||
|
|| (parsed.countryCode != null && typeof parsed.countryCode !== 'string')
|
||||||
|
|| (parsed.paymentEnv != null && typeof parsed.paymentEnv !== 'string')
|
||||||
|| typeof parsed.payAmount !== 'number'
|
|| typeof parsed.payAmount !== 'number'
|
||||||
|| typeof parsed.paymentMode !== 'string'
|
|| typeof parsed.paymentMode !== 'string'
|
||||||
|| typeof parsed.resumeToken !== 'string'
|
|| typeof parsed.resumeToken !== 'string'
|
||||||
@ -265,6 +288,10 @@ export function readPaymentRecoverySnapshot(
|
|||||||
payUrl: parsed.payUrl,
|
payUrl: parsed.payUrl,
|
||||||
outTradeNo: parsed.outTradeNo || '',
|
outTradeNo: parsed.outTradeNo || '',
|
||||||
clientSecret: parsed.clientSecret,
|
clientSecret: parsed.clientSecret,
|
||||||
|
intentId: parsed.intentId || '',
|
||||||
|
currency: parsed.currency || '',
|
||||||
|
countryCode: parsed.countryCode || '',
|
||||||
|
paymentEnv: parsed.paymentEnv || '',
|
||||||
payAmount: parsed.payAmount,
|
payAmount: parsed.payAmount,
|
||||||
orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance',
|
orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance',
|
||||||
paymentMode: parsed.paymentMode,
|
paymentMode: parsed.paymentMode,
|
||||||
|
|||||||
@ -9,12 +9,16 @@ export interface ConfigFieldDef {
|
|||||||
label: string
|
label: string
|
||||||
sensitive: boolean
|
sensitive: boolean
|
||||||
optional?: boolean
|
optional?: boolean
|
||||||
|
clearable?: boolean
|
||||||
defaultValue?: string
|
defaultValue?: string
|
||||||
|
hintKey?: string
|
||||||
|
options?: TypeOption[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TypeOption {
|
export interface TypeOption {
|
||||||
value: string
|
value: string
|
||||||
label: string
|
label: string
|
||||||
|
[key: string]: unknown
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Callback URL paths for a provider. */
|
/** Callback URL paths for a provider. */
|
||||||
@ -31,18 +35,36 @@ export const PROVIDER_SUPPORTED_TYPES: Record<string, string[]> = {
|
|||||||
alipay: ['alipay'],
|
alipay: ['alipay'],
|
||||||
wxpay: ['wxpay'],
|
wxpay: ['wxpay'],
|
||||||
stripe: ['card', 'alipay', 'wxpay', 'link'],
|
stripe: ['card', 'alipay', 'wxpay', 'link'],
|
||||||
|
airwallex: ['airwallex'],
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Available payment modes for EasyPay providers. */
|
/** Available payment modes for EasyPay providers. */
|
||||||
export const EASYPAY_PAYMENT_MODES = ['qrcode', 'popup'] as const
|
export const EASYPAY_PAYMENT_MODES = ['qrcode', 'popup'] as const
|
||||||
|
|
||||||
/** Fixed display order for user-facing payment methods */
|
/** Fixed display order for user-facing payment methods */
|
||||||
export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct', 'stripe'] as const
|
export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct', 'stripe', 'airwallex'] as const
|
||||||
|
|
||||||
/** Payment mode constants */
|
/** Payment mode constants */
|
||||||
export const PAYMENT_MODE_QRCODE = 'qrcode'
|
export const PAYMENT_MODE_QRCODE = 'qrcode'
|
||||||
export const PAYMENT_MODE_POPUP = 'popup'
|
export const PAYMENT_MODE_POPUP = 'popup'
|
||||||
|
|
||||||
|
export const PAYMENT_CURRENCY_OPTIONS: TypeOption[] = [
|
||||||
|
{ value: 'CNY', label: 'CNY' },
|
||||||
|
{ value: 'HKD', label: 'HKD' },
|
||||||
|
{ value: 'USD', label: 'USD' },
|
||||||
|
{ value: 'EUR', label: 'EUR' },
|
||||||
|
{ value: 'GBP', label: 'GBP' },
|
||||||
|
{ value: 'AUD', label: 'AUD' },
|
||||||
|
{ value: 'CAD', label: 'CAD' },
|
||||||
|
{ value: 'SGD', label: 'SGD' },
|
||||||
|
{ value: 'JPY', label: 'JPY' },
|
||||||
|
{ value: 'KRW', label: 'KRW' },
|
||||||
|
{ value: 'NZD', label: 'NZD' },
|
||||||
|
]
|
||||||
|
|
||||||
|
// 与后端当前集成的 stripe-go v85.0.0 的 stripe.APIVersion 保持一致。
|
||||||
|
export const STRIPE_SDK_API_VERSION = '2026-03-25.dahlia'
|
||||||
|
|
||||||
/** Preferred popup size for payment gateways. Alipay's standard checkout
|
/** Preferred popup size for payment gateways. Alipay's standard checkout
|
||||||
* (QR + account login panel) needs ~1200×900 to render without any scrolling. */
|
* (QR + account login panel) needs ~1200×900 to render without any scrolling. */
|
||||||
const PAYMENT_POPUP_PREFERRED_WIDTH = 1250
|
const PAYMENT_POPUP_PREFERRED_WIDTH = 1250
|
||||||
@ -68,6 +90,7 @@ export const WEBHOOK_PATHS: Record<string, string> = {
|
|||||||
alipay: '/api/v1/payment/webhook/alipay',
|
alipay: '/api/v1/payment/webhook/alipay',
|
||||||
wxpay: '/api/v1/payment/webhook/wxpay',
|
wxpay: '/api/v1/payment/webhook/wxpay',
|
||||||
stripe: '/api/v1/payment/webhook/stripe',
|
stripe: '/api/v1/payment/webhook/stripe',
|
||||||
|
airwallex: '/api/v1/payment/webhook/airwallex',
|
||||||
}
|
}
|
||||||
|
|
||||||
export const RETURN_PATH = '/payment/result'
|
export const RETURN_PATH = '/payment/result'
|
||||||
@ -77,7 +100,8 @@ export const PROVIDER_CALLBACK_PATHS: Record<string, CallbackPaths> = {
|
|||||||
easypay: { notifyUrl: WEBHOOK_PATHS.easypay, returnUrl: RETURN_PATH },
|
easypay: { notifyUrl: WEBHOOK_PATHS.easypay, returnUrl: RETURN_PATH },
|
||||||
alipay: { notifyUrl: WEBHOOK_PATHS.alipay, returnUrl: RETURN_PATH },
|
alipay: { notifyUrl: WEBHOOK_PATHS.alipay, returnUrl: RETURN_PATH },
|
||||||
wxpay: { notifyUrl: WEBHOOK_PATHS.wxpay },
|
wxpay: { notifyUrl: WEBHOOK_PATHS.wxpay },
|
||||||
// stripe: no callback URL config needed (webhook is separate)
|
// stripe: 不需要回调 URL 配置,Webhook 单独配置。
|
||||||
|
// airwallex: 不需要回调 URL 配置,Webhook 在空中云汇后台配置。
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Per-provider config fields (excludes notifyUrl/returnUrl which are handled separately). */
|
/** Per-provider config fields (excludes notifyUrl/returnUrl which are handled separately). */
|
||||||
@ -107,6 +131,16 @@ export const PROVIDER_CONFIG_FIELDS: Record<string, ConfigFieldDef[]> = {
|
|||||||
{ key: 'secretKey', label: '', sensitive: true },
|
{ key: 'secretKey', label: '', sensitive: true },
|
||||||
{ key: 'publishableKey', label: '', sensitive: false },
|
{ key: 'publishableKey', label: '', sensitive: false },
|
||||||
{ key: 'webhookSecret', label: '', sensitive: true },
|
{ key: 'webhookSecret', label: '', sensitive: true },
|
||||||
|
{ key: 'currency', label: '', sensitive: false, defaultValue: 'CNY', hintKey: 'admin.settings.payment.field_paymentCurrencyHint', options: PAYMENT_CURRENCY_OPTIONS },
|
||||||
|
],
|
||||||
|
airwallex: [
|
||||||
|
{ key: 'clientId', label: '', sensitive: false },
|
||||||
|
{ key: 'apiKey', label: '', sensitive: true },
|
||||||
|
{ key: 'webhookSecret', label: '', sensitive: true },
|
||||||
|
{ key: 'apiBase', label: '', sensitive: false, defaultValue: 'https://api.airwallex.com/api/v1', hintKey: 'admin.settings.payment.field_airwallexApiBaseHint' },
|
||||||
|
{ key: 'countryCode', label: '', sensitive: false, defaultValue: 'CN' },
|
||||||
|
{ key: 'currency', label: '', sensitive: false, defaultValue: 'CNY', hintKey: 'admin.settings.payment.field_paymentCurrencyHint', options: PAYMENT_CURRENCY_OPTIONS },
|
||||||
|
{ key: 'accountId', label: '', sensitive: false, optional: true, clearable: true, hintKey: 'admin.settings.payment.field_accountIdHint' },
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5396,6 +5396,11 @@ export default {
|
|||||||
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
|
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
|
||||||
anthropicCacheTTL1hInjection: 'Anthropic Cache TTL Injection',
|
anthropicCacheTTL1hInjection: 'Anthropic Cache TTL Injection',
|
||||||
anthropicCacheTTL1hInjectionHint: 'When enabled, existing ephemeral cache_control blocks in Anthropic OAuth/Setup Token request bodies are forced to 1h; response usage is billed back as 5m by default, with account-level TTL billing override taking priority.',
|
anthropicCacheTTL1hInjectionHint: 'When enabled, existing ephemeral cache_control blocks in Anthropic OAuth/Setup Token request bodies are forced to 1h; response usage is billed back as 5m by default, with account-level TTL billing override taking priority.',
|
||||||
|
rewriteMessageCacheControl: 'Rewrite Message Cache Breakpoints',
|
||||||
|
rewriteMessageCacheControlHint: 'Default off: preserve client cache_control on message content blocks. When enabled, client breakpoints are stripped and proxy breakpoints are injected for clients that do not manage caching themselves.',
|
||||||
|
antigravityUserAgentVersion: 'Antigravity UA Version',
|
||||||
|
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
||||||
|
antigravityUserAgentVersionHint: 'Leave empty to use ANTIGRAVITY_USER_AGENT_VERSION or the built-in default 1.23.2; when set, the admin setting takes precedence.',
|
||||||
},
|
},
|
||||||
webSearchEmulation: {
|
webSearchEmulation: {
|
||||||
title: 'Web Search Emulation',
|
title: 'Web Search Emulation',
|
||||||
@ -5575,6 +5580,7 @@ export default {
|
|||||||
providerAlipay: 'Alipay (Direct)',
|
providerAlipay: 'Alipay (Direct)',
|
||||||
providerWxpay: 'WeChat Pay (Direct)',
|
providerWxpay: 'WeChat Pay (Direct)',
|
||||||
providerStripe: 'Stripe',
|
providerStripe: 'Stripe',
|
||||||
|
providerAirwallex: 'Airwallex',
|
||||||
typeDisabled: 'type disabled',
|
typeDisabled: 'type disabled',
|
||||||
enableTypesFirst: 'Enable at least one payment type above first',
|
enableTypesFirst: 'Enable at least one payment type above first',
|
||||||
easypayRedirect: 'Redirect',
|
easypayRedirect: 'Redirect',
|
||||||
@ -5601,12 +5607,24 @@ export default {
|
|||||||
wxpayConfigHint: 'WeChat Pay usually only needs App ID. Fill MP App ID, H5 App Name, and H5 App URL only when your Official Account or H5 flow specifically requires them.',
|
wxpayConfigHint: 'WeChat Pay usually only needs App ID. Fill MP App ID, H5 App Name, and H5 App URL only when your Official Account or H5 flow specifically requires them.',
|
||||||
wxpayAdvancedOptions: 'WeChat Pay Advanced Options',
|
wxpayAdvancedOptions: 'WeChat Pay Advanced Options',
|
||||||
field_secretKey: 'Secret Key',
|
field_secretKey: 'Secret Key',
|
||||||
|
field_clientId: 'Client ID',
|
||||||
|
field_apiKey: 'API Key',
|
||||||
field_publishableKey: 'Publishable Key',
|
field_publishableKey: 'Publishable Key',
|
||||||
field_webhookSecret: 'Webhook Secret',
|
field_webhookSecret: 'Webhook Secret',
|
||||||
|
field_countryCode: 'Country/region code',
|
||||||
|
field_currency: 'Payment currency',
|
||||||
|
field_accountId: 'Airwallex Account ID',
|
||||||
|
field_airwallexApiBaseHint: 'Must match the API key environment: use https://api-demo.airwallex.com/api/v1 for sandbox/demo keys, and https://api.airwallex.com/api/v1 for production keys. Mixed environments return credentials_invalid / Access Denied.',
|
||||||
|
field_paymentCurrencyHint: 'Default is CNY. Stripe and Airwallex can choose HKD, USD, or another listed currency supported by the account; WeChat Pay, Alipay, and EasyPay remain CNY.',
|
||||||
|
field_accountIdHint: 'Leave this empty unless you use multiple accounts, an organization-level key, or connected-account payments. A single-account scoped API key uses the selected account by default.',
|
||||||
field_cid: 'Channel ID',
|
field_cid: 'Channel ID',
|
||||||
field_cidAlipay: 'Alipay Channel ID',
|
field_cidAlipay: 'Alipay Channel ID',
|
||||||
field_cidWxpay: 'WeChat Channel ID',
|
field_cidWxpay: 'WeChat Channel ID',
|
||||||
stripeWebhookHint: 'Configure the following URL as a Webhook endpoint in Stripe Dashboard:',
|
stripeWebhookHint: 'Configure the following URL as a Webhook endpoint in Stripe Dashboard:',
|
||||||
|
stripeWebhookApiVersionHint: 'Set this Webhook endpoint API version to match the integrated Stripe SDK. Recommended: {version}. A mismatch can cause webhook parsing errors.',
|
||||||
|
airwallexWebhookHint: 'Configure the following URL as a Webhook endpoint in Airwallex. Select at least Payment Intent -> Succeeded (payment_intent.succeeded), preferably also Payment Intent -> Cancelled (payment_intent.cancelled). Use the account default or latest stable API version.',
|
||||||
|
airwallexGuideSummary: 'When creating an Airwallex scoped API key, select Read and Write for Payment Acceptance under account-level permissions.',
|
||||||
|
airwallexGuideNote: 'Do not grant unrelated permissions such as Spend, Payouts, Transfers, Funds Splits, or POS Terminals unless you explicitly need them. For webhooks, select at least payment_intent.succeeded, preferably also payment_intent.cancelled, and use the account default or latest stable API version.',
|
||||||
limitsTitle: 'Limits',
|
limitsTitle: 'Limits',
|
||||||
limitSingleMin: 'Min per order',
|
limitSingleMin: 'Min per order',
|
||||||
limitSingleMax: 'Max per order',
|
limitSingleMax: 'Max per order',
|
||||||
@ -6529,6 +6547,7 @@ export default {
|
|||||||
alipay: 'Alipay',
|
alipay: 'Alipay',
|
||||||
wxpay: 'WeChat Pay',
|
wxpay: 'WeChat Pay',
|
||||||
stripe: 'Stripe',
|
stripe: 'Stripe',
|
||||||
|
airwallex: 'Airwallex',
|
||||||
card: 'Card',
|
card: 'Card',
|
||||||
link: 'Link',
|
link: 'Link',
|
||||||
alipay_direct: 'Alipay (Direct)',
|
alipay_direct: 'Alipay (Direct)',
|
||||||
@ -6614,6 +6633,8 @@ export default {
|
|||||||
stripeLoadFailed: 'Failed to load payment component. Please refresh and try again.',
|
stripeLoadFailed: 'Failed to load payment component. Please refresh and try again.',
|
||||||
stripeMissingParams: 'Missing order ID or client secret',
|
stripeMissingParams: 'Missing order ID or client secret',
|
||||||
stripeNotConfigured: 'Stripe is not configured',
|
stripeNotConfigured: 'Stripe is not configured',
|
||||||
|
airwallexLoadFailed: 'Failed to load Airwallex payment component. Please refresh and try again.',
|
||||||
|
airwallexMissingParams: 'Missing Airwallex payment parameters',
|
||||||
errors: {
|
errors: {
|
||||||
tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.',
|
tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.',
|
||||||
cancelRateLimited: 'Too many cancellations. Please try again later.',
|
cancelRateLimited: 'Too many cancellations. Please try again later.',
|
||||||
@ -6659,6 +6680,7 @@ export default {
|
|||||||
REFUND_AMOUNT_EXCEEDED: 'Refund amount exceeds the recharge amount.',
|
REFUND_AMOUNT_EXCEEDED: 'Refund amount exceeds the recharge amount.',
|
||||||
REFUND_FAILED: 'Refund failed.',
|
REFUND_FAILED: 'Refund failed.',
|
||||||
},
|
},
|
||||||
|
airwallexPay: 'Airwallex Payment',
|
||||||
stripePay: 'Pay Now',
|
stripePay: 'Pay Now',
|
||||||
stripeSuccessProcessing: 'Payment successful, processing your order...',
|
stripeSuccessProcessing: 'Payment successful, processing your order...',
|
||||||
stripePopup: {
|
stripePopup: {
|
||||||
|
|||||||
@ -5555,6 +5555,11 @@ export default {
|
|||||||
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
|
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
|
||||||
anthropicCacheTTL1hInjection: 'Anthropic 缓存 TTL 注入',
|
anthropicCacheTTL1hInjection: 'Anthropic 缓存 TTL 注入',
|
||||||
anthropicCacheTTL1hInjectionHint: '开启后,对 Anthropic OAuth/Setup Token 请求体中已有的 ephemeral 缓存块强制写入 1h;响应 usage 默认按 5m 回写计费,账号级 TTL 计费设置优先。',
|
anthropicCacheTTL1hInjectionHint: '开启后,对 Anthropic OAuth/Setup Token 请求体中已有的 ephemeral 缓存块强制写入 1h;响应 usage 默认按 5m 回写计费,账号级 TTL 计费设置优先。',
|
||||||
|
rewriteMessageCacheControl: '改写消息缓存断点',
|
||||||
|
rewriteMessageCacheControlHint: '默认关闭,保留客户端在 messages 内容块中的 cache_control。开启后会清除客户端断点并注入代理断点,适合不自行管理缓存策略的客户端。',
|
||||||
|
antigravityUserAgentVersion: 'Antigravity UA 版本',
|
||||||
|
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
||||||
|
antigravityUserAgentVersionHint: '留空时使用 ANTIGRAVITY_USER_AGENT_VERSION 或内置默认值 1.23.2;填写后后台设置优先。',
|
||||||
},
|
},
|
||||||
webSearchEmulation: {
|
webSearchEmulation: {
|
||||||
title: 'Web Search 模拟',
|
title: 'Web Search 模拟',
|
||||||
@ -5736,6 +5741,7 @@ export default {
|
|||||||
providerAlipay: '支付宝官方',
|
providerAlipay: '支付宝官方',
|
||||||
providerWxpay: '微信官方',
|
providerWxpay: '微信官方',
|
||||||
providerStripe: 'Stripe',
|
providerStripe: 'Stripe',
|
||||||
|
providerAirwallex: 'Airwallex',
|
||||||
typeDisabled: '类型已禁用',
|
typeDisabled: '类型已禁用',
|
||||||
enableTypesFirst: '请先在上方启用至少一种服务商',
|
enableTypesFirst: '请先在上方启用至少一种服务商',
|
||||||
easypayRedirect: '跳转',
|
easypayRedirect: '跳转',
|
||||||
@ -5762,12 +5768,24 @@ export default {
|
|||||||
wxpayConfigHint: '微信支付通常只需要填写 App ID。公众号 App ID、H5 应用名称、H5 应用地址仅在公众号支付或 H5 场景有特殊要求时再填写。',
|
wxpayConfigHint: '微信支付通常只需要填写 App ID。公众号 App ID、H5 应用名称、H5 应用地址仅在公众号支付或 H5 场景有特殊要求时再填写。',
|
||||||
wxpayAdvancedOptions: '微信支付高级可选项',
|
wxpayAdvancedOptions: '微信支付高级可选项',
|
||||||
field_secretKey: '密钥',
|
field_secretKey: '密钥',
|
||||||
|
field_clientId: 'Client ID',
|
||||||
|
field_apiKey: 'API Key',
|
||||||
field_publishableKey: '公开密钥',
|
field_publishableKey: '公开密钥',
|
||||||
field_webhookSecret: 'Webhook 密钥',
|
field_webhookSecret: 'Webhook 密钥',
|
||||||
|
field_countryCode: '国家/地区代码',
|
||||||
|
field_currency: '支付币种',
|
||||||
|
field_accountId: 'Airwallex 账户 ID',
|
||||||
|
field_airwallexApiBaseHint: '必须和 API Key 所属环境一致:沙箱/测试密钥使用 https://api-demo.airwallex.com/api/v1,生产密钥使用 https://api.airwallex.com/api/v1。环境混用会返回 credentials_invalid / Access Denied。',
|
||||||
|
field_paymentCurrencyHint: '默认 CNY。Stripe 和 Airwallex 可按账户支持从下拉项选择 HKD、USD 等币种;微信、支付宝、易支付仍按 CNY。',
|
||||||
|
field_accountIdHint: '不涉及多账户、组织级密钥或连接账户收款时可以不填;单账户 Scoped API Key 会默认使用所选账户。',
|
||||||
field_cid: '支付渠道 ID',
|
field_cid: '支付渠道 ID',
|
||||||
field_cidAlipay: '支付宝渠道 ID',
|
field_cidAlipay: '支付宝渠道 ID',
|
||||||
field_cidWxpay: '微信渠道 ID',
|
field_cidWxpay: '微信渠道 ID',
|
||||||
stripeWebhookHint: '请在 Stripe Dashboard 中将以下地址配置为 Webhook 端点:',
|
stripeWebhookHint: '请在 Stripe Dashboard 中将以下地址配置为 Webhook 端点:',
|
||||||
|
stripeWebhookApiVersionHint: 'Webhook 端点的 API 版本请与当前集成的 Stripe SDK 对齐,建议选择 {version};版本不一致可能导致回调事件解析失败。',
|
||||||
|
airwallexWebhookHint: '请在 Airwallex 后台将以下地址配置为 Webhook 端点;事件至少选择 Payment Intent -> Succeeded(payment_intent.succeeded),建议同时选择 Payment Intent -> Cancelled(payment_intent.cancelled);API version 选择账户默认或最新稳定版本。',
|
||||||
|
airwallexGuideSummary: '创建 Airwallex Scoped API 密钥时,建议只在账户级权限中为 Payment Acceptance 勾选读取和写入。',
|
||||||
|
airwallexGuideNote: '不需要勾选 Spend、Payouts、Transfers、Funds Splits、POS 终端等与在线收款无关的权限。Webhook 事件至少选择 payment_intent.succeeded,建议同时选择 payment_intent.cancelled;API version 选择账户默认或最新稳定版本。',
|
||||||
limitsTitle: '限额配置',
|
limitsTitle: '限额配置',
|
||||||
limitSingleMin: '单笔最低',
|
limitSingleMin: '单笔最低',
|
||||||
limitSingleMax: '单笔最高',
|
limitSingleMax: '单笔最高',
|
||||||
@ -6714,6 +6732,7 @@ export default {
|
|||||||
alipay: '支付宝',
|
alipay: '支付宝',
|
||||||
wxpay: '微信支付',
|
wxpay: '微信支付',
|
||||||
stripe: 'Stripe',
|
stripe: 'Stripe',
|
||||||
|
airwallex: 'Airwallex',
|
||||||
card: '银行卡',
|
card: '银行卡',
|
||||||
link: 'Link',
|
link: 'Link',
|
||||||
alipay_direct: '支付宝(直连)',
|
alipay_direct: '支付宝(直连)',
|
||||||
@ -6799,6 +6818,8 @@ export default {
|
|||||||
stripeLoadFailed: '支付组件加载失败,请刷新页面重试',
|
stripeLoadFailed: '支付组件加载失败,请刷新页面重试',
|
||||||
stripeMissingParams: '缺少订单ID或支付密钥',
|
stripeMissingParams: '缺少订单ID或支付密钥',
|
||||||
stripeNotConfigured: 'Stripe 未配置',
|
stripeNotConfigured: 'Stripe 未配置',
|
||||||
|
airwallexLoadFailed: 'Airwallex 支付组件加载失败,请刷新页面重试',
|
||||||
|
airwallexMissingParams: '缺少 Airwallex 支付参数',
|
||||||
errors: {
|
errors: {
|
||||||
tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单',
|
tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单',
|
||||||
cancelRateLimited: '取消订单过于频繁,请稍后再试',
|
cancelRateLimited: '取消订单过于频繁,请稍后再试',
|
||||||
@ -6844,6 +6865,7 @@ export default {
|
|||||||
REFUND_AMOUNT_EXCEEDED: '退款金额超过充值金额',
|
REFUND_AMOUNT_EXCEEDED: '退款金额超过充值金额',
|
||||||
REFUND_FAILED: '退款失败',
|
REFUND_FAILED: '退款失败',
|
||||||
},
|
},
|
||||||
|
airwallexPay: 'Airwallex 支付',
|
||||||
stripePay: '立即支付',
|
stripePay: '立即支付',
|
||||||
stripeSuccessProcessing: '支付成功,正在处理订单...',
|
stripeSuccessProcessing: '支付成功,正在处理订单...',
|
||||||
stripePopup: {
|
stripePopup: {
|
||||||
|
|||||||
@ -316,6 +316,18 @@ const routes: RouteRecordRaw[] = [
|
|||||||
requiresPayment: false
|
requiresPayment: false
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
path: '/payment/airwallex',
|
||||||
|
name: 'AirwallexPayment',
|
||||||
|
component: () => import('@/views/user/AirwallexPaymentView.vue'),
|
||||||
|
meta: {
|
||||||
|
requiresAuth: false,
|
||||||
|
requiresAdmin: false,
|
||||||
|
title: 'Airwallex Payment',
|
||||||
|
titleKey: 'payment.airwallexPay',
|
||||||
|
requiresPayment: false
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
path: '/payment/stripe-popup',
|
path: '/payment/stripe-popup',
|
||||||
name: 'StripePopup',
|
name: 'StripePopup',
|
||||||
@ -680,7 +692,7 @@ let authInitialized = false
|
|||||||
const navigationLoading = useNavigationLoadingState()
|
const navigationLoading = useNavigationLoadingState()
|
||||||
// 延迟初始化预加载,传入 router 实例
|
// 延迟初始化预加载,传入 router 实例
|
||||||
let routePrefetch: ReturnType<typeof useRoutePrefetch> | null = null
|
let routePrefetch: ReturnType<typeof useRoutePrefetch> | null = null
|
||||||
const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup', '/payment/result', '/legal']
|
const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup', '/payment/result', '/payment/airwallex', '/legal']
|
||||||
const BACKEND_MODE_CALLBACK_PATHS = [
|
const BACKEND_MODE_CALLBACK_PATHS = [
|
||||||
'/auth/callback',
|
'/auth/callback',
|
||||||
'/auth/linuxdo/callback',
|
'/auth/linuxdo/callback',
|
||||||
|
|||||||
@ -121,6 +121,13 @@
|
|||||||
@apply dark:hover:bg-[#635bff];
|
@apply dark:hover:bg-[#635bff];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.btn-airwallex {
|
||||||
|
@apply bg-[#14171A] text-white shadow-md shadow-emerald-500/20;
|
||||||
|
@apply hover:bg-[#20252a] hover:shadow-lg hover:shadow-emerald-500/25;
|
||||||
|
@apply dark:bg-[#7AF0C4] dark:text-gray-950 dark:shadow-emerald-500/20;
|
||||||
|
@apply dark:hover:bg-[#62d9ad];
|
||||||
|
}
|
||||||
|
|
||||||
.btn-alipay {
|
.btn-alipay {
|
||||||
@apply bg-[#00AEEF] text-white shadow-md shadow-[#00AEEF]/25;
|
@apply bg-[#00AEEF] text-white shadow-md shadow-[#00AEEF]/25;
|
||||||
@apply hover:bg-[#009dd6] hover:shadow-lg hover:shadow-[#00AEEF]/30;
|
@apply hover:bg-[#009dd6] hover:shadow-lg hover:shadow-[#00AEEF]/30;
|
||||||
|
|||||||
@ -18,7 +18,7 @@ export type OrderStatus =
|
|||||||
| 'REFUNDED'
|
| 'REFUNDED'
|
||||||
| 'REFUND_FAILED'
|
| 'REFUND_FAILED'
|
||||||
|
|
||||||
export type PaymentType = 'alipay' | 'wxpay' | 'alipay_direct' | 'wxpay_direct' | 'stripe' | 'easypay'
|
export type PaymentType = 'alipay' | 'wxpay' | 'alipay_direct' | 'wxpay_direct' | 'stripe' | 'easypay' | 'airwallex'
|
||||||
|
|
||||||
export type OrderType = 'balance' | 'subscription'
|
export type OrderType = 'balance' | 'subscription'
|
||||||
|
|
||||||
@ -40,6 +40,7 @@ export interface PaymentConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface MethodLimit {
|
export interface MethodLimit {
|
||||||
|
currency?: string
|
||||||
daily_limit: number
|
daily_limit: number
|
||||||
daily_used: number
|
daily_used: number
|
||||||
daily_remaining: number
|
daily_remaining: number
|
||||||
@ -77,6 +78,7 @@ export interface PaymentOrder {
|
|||||||
user_id: number
|
user_id: number
|
||||||
amount: number
|
amount: number
|
||||||
pay_amount: number
|
pay_amount: number
|
||||||
|
currency?: string
|
||||||
fee_rate: number
|
fee_rate: number
|
||||||
payment_type: string
|
payment_type: string
|
||||||
out_trade_no: string
|
out_trade_no: string
|
||||||
@ -187,6 +189,10 @@ export interface CreateOrderResult {
|
|||||||
pay_url?: string
|
pay_url?: string
|
||||||
qr_code?: string
|
qr_code?: string
|
||||||
client_secret?: string
|
client_secret?: string
|
||||||
|
intent_id?: string
|
||||||
|
currency?: string
|
||||||
|
country_code?: string
|
||||||
|
payment_env?: string
|
||||||
pay_amount: number
|
pay_amount: number
|
||||||
fee_rate: number
|
fee_rate: number
|
||||||
expires_at: string
|
expires_at: string
|
||||||
|
|||||||
67
frontend/src/utils/__tests__/ccswitchImport.spec.ts
Normal file
67
frontend/src/utils/__tests__/ccswitchImport.spec.ts
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
import {
|
||||||
|
OPENAI_CC_SWITCH_CODEX_MODEL,
|
||||||
|
buildCcSwitchImportDeeplink
|
||||||
|
} from '@/utils/ccswitchImport'
|
||||||
|
import type { GroupPlatform } from '@/types'
|
||||||
|
|
||||||
|
function paramsFromDeeplink(deeplink: string): URLSearchParams {
|
||||||
|
const query = deeplink.split('?')[1] || ''
|
||||||
|
return new URLSearchParams(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('ccswitchImport utils', () => {
|
||||||
|
const baseInput = {
|
||||||
|
baseUrl: 'https://api.example.com',
|
||||||
|
providerName: 'Sub2API',
|
||||||
|
apiKey: 'sk-test',
|
||||||
|
usageScript: 'return true'
|
||||||
|
}
|
||||||
|
|
||||||
|
it('adds the Codex model parameter for OpenAI imports', () => {
|
||||||
|
const params = paramsFromDeeplink(
|
||||||
|
buildCcSwitchImportDeeplink({
|
||||||
|
...baseInput,
|
||||||
|
platform: 'openai',
|
||||||
|
clientType: 'claude'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(params.get('resource')).toBe('provider')
|
||||||
|
expect(params.get('app')).toBe('codex')
|
||||||
|
expect(params.get('endpoint')).toBe(baseInput.baseUrl)
|
||||||
|
expect(params.get('model')).toBe(OPENAI_CC_SWITCH_CODEX_MODEL)
|
||||||
|
expect(atob(params.get('usageScript') || '')).toBe(baseInput.usageScript)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.each([
|
||||||
|
{ platform: 'anthropic' as GroupPlatform, clientType: 'claude' as const, app: 'claude' },
|
||||||
|
{ platform: 'gemini' as GroupPlatform, clientType: 'gemini' as const, app: 'gemini' }
|
||||||
|
])('does not add a model parameter for $platform imports', ({ platform, clientType, app }) => {
|
||||||
|
const params = paramsFromDeeplink(
|
||||||
|
buildCcSwitchImportDeeplink({
|
||||||
|
...baseInput,
|
||||||
|
platform,
|
||||||
|
clientType
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(params.get('app')).toBe(app)
|
||||||
|
expect(params.get('endpoint')).toBe(baseInput.baseUrl)
|
||||||
|
expect(params.has('model')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps Antigravity imports on the selected client endpoint without a model parameter', () => {
|
||||||
|
const params = paramsFromDeeplink(
|
||||||
|
buildCcSwitchImportDeeplink({
|
||||||
|
...baseInput,
|
||||||
|
platform: 'antigravity',
|
||||||
|
clientType: 'gemini'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(params.get('app')).toBe('gemini')
|
||||||
|
expect(params.get('endpoint')).toBe(`${baseInput.baseUrl}/antigravity`)
|
||||||
|
expect(params.has('model')).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user