merge: resolve upstream main conflicts for bulk OpenAI passthrough

This commit is contained in:
Wang Lvyuan 2026-03-24 19:27:51 +08:00
commit bb399e56b0
98 changed files with 5168 additions and 213 deletions

View File

@ -61,6 +61,9 @@ temp/
deploy/install.sh deploy/install.sh
deploy/sub2api.service deploy/sub2api.service
deploy/sub2api-sudoers deploy/sub2api-sudoers
deploy/data/
deploy/postgres_data/
deploy/redis_data/
# GoReleaser # GoReleaser
.goreleaser.yaml .goreleaser.yaml

View File

@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient() driveClient := repository.NewGeminiDriveClient()

View File

@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Account var out []service.Account
for { for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType := c.Query("type") accountType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil return nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }

View File

@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
@ -176,6 +177,7 @@ type UpdateSettingsRequest struct {
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
@ -231,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 { if req.DefaultBalance < 0 {
req.DefaultBalance = 0 req.DefaultBalance = 0
} }
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
if req.SMTPPort <= 0 { if req.SMTPPort <= 0 {
req.SMTPPort = 587 req.SMTPPort = 587
} }
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
if req.SMTPHost == "" && previousSettings.SMTPHost != "" {
req.SMTPHost = previousSettings.SMTPHost
req.SMTPPort = previousSettings.SMTPPort
req.SMTPUsername = previousSettings.SMTPUsername
req.SMTPFrom = previousSettings.SMTPFrom
req.SMTPFromName = previousSettings.SMTPFromName
req.SMTPUseTLS = previousSettings.SMTPUseTLS
}
// Turnstile 参数验证 // Turnstile 参数验证
if req.TurnstileEnabled { if req.TurnstileEnabled {
// 检查必填字段 // 检查必填字段
@ -417,6 +435,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
customMenuJSON = string(menuBytes) customMenuJSON = string(menuBytes)
} }
// 自定义端点验证
const (
maxCustomEndpoints = 10
maxEndpointNameLen = 50
maxEndpointURLLen = 2048
maxEndpointDescriptionLen = 200
)
customEndpointsJSON := previousSettings.CustomEndpoints
if req.CustomEndpoints != nil {
endpoints := *req.CustomEndpoints
if len(endpoints) > maxCustomEndpoints {
response.BadRequest(c, "Too many custom endpoints (max 10)")
return
}
for _, ep := range endpoints {
if strings.TrimSpace(ep.Name) == "" {
response.BadRequest(c, "Custom endpoint name is required")
return
}
if len(ep.Name) > maxEndpointNameLen {
response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)")
return
}
if strings.TrimSpace(ep.Endpoint) == "" {
response.BadRequest(c, "Custom endpoint URL is required")
return
}
if len(ep.Endpoint) > maxEndpointURLLen {
response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil {
response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL")
return
}
if len(ep.Description) > maxEndpointDescriptionLen {
response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)")
return
}
}
endpointBytes, err := json.Marshal(endpoints)
if err != nil {
response.BadRequest(c, "Failed to serialize custom endpoints")
return
}
customEndpointsJSON = string(endpointBytes)
}
// Ops metrics collector interval validation (seconds). // Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil { if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds v := *req.OpsMetricsIntervalSeconds
@ -495,6 +562,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: purchaseURL, PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled, SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON, CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
@ -592,6 +660,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled, SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions, DefaultSubscriptions: updatedDefaultSubscriptions,
@ -828,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"` SMTPPassword string `json:"smtp_password"`
@ -844,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
return return
} }
if req.SMTPPort <= 0 { req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPPort = 587 req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
} }
// 如果未提供密码,从数据库获取已保存的密码 if req.SMTPHost == "" && savedConfig != nil {
password := req.SMTPPassword req.SMTPHost = savedConfig.Host
if password == "" { }
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) if req.SMTPPort <= 0 {
if err == nil && savedConfig != nil { if savedConfig != nil && savedConfig.Port > 0 {
password = savedConfig.Password req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
} }
} }
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{ config := &service.SMTPConfig{
Host: req.SMTPHost, Host: req.SMTPHost,
@ -877,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求 // SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct { type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"` SMTPPassword string `json:"smtp_password"`
@ -895,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return return
} }
if req.SMTPPort <= 0 { req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPPort = 587 req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
} }
// 如果未提供密码,从数据库获取已保存的密码 if req.SMTPHost == "" && savedConfig != nil {
password := req.SMTPPassword req.SMTPHost = savedConfig.Host
if password == "" { }
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) if req.SMTPPort <= 0 {
if err == nil && savedConfig != nil { if savedConfig != nil && savedConfig.Port > 0 {
password = savedConfig.Password req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
} }
} }
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPFrom == "" && savedConfig != nil {
req.SMTPFrom = savedConfig.From
}
if req.SMTPFromName == "" && savedConfig != nil {
req.SMTPFromName = savedConfig.FromName
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{ config := &service.SMTPConfig{
Host: req.SMTPHost, Host: req.SMTPHost,

View File

@ -15,6 +15,13 @@ type CustomMenuItem struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
// CustomEndpoint represents an admin-configured API endpoint for quick copy.
type CustomEndpoint struct {
Name string `json:"name"`
Endpoint string `json:"endpoint"`
Description string `json:"description"`
}
// SystemSettings represents the admin settings API response payload. // SystemSettings represents the admin settings API response payload.
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
@ -56,6 +63,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
@ -114,6 +122,7 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
} }
return filtered return filtered
} }
// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint.
// Returns empty slice on empty/invalid input.
func ParseCustomEndpoints(raw string) []CustomEndpoint {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return []CustomEndpoint{}
}
var items []CustomEndpoint
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return []CustomEndpoint{}
}
return items
}

View File

@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 验证 model 必填 // 验证 model 必填
if reqModel == "" { if reqModel == "" {
@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
// 获取订阅信息可能为nil // 获取订阅信息可能为nil
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)

View File

@ -0,0 +1,289 @@
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
// POST /v1/chat/completions
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleCCFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.cc.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -0,0 +1,295 @@
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
// POST /v1/responses
// This converts Responses API requests to Anthropic format, forwards to Anthropic
// upstream, and converts responses back to Responses format.
func (h *GatewayHandler) Responses(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream using gjson (like OpenAI handler)
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
// The existing service-layer checkClaudeCodeRestriction handles degradation
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.responsesErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Can't failover if streaming content already sent
if c.Writer.Size() != writerSizeBeforeForward {
h.handleResponsesFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.responses.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// responsesErrorResponse writes an error in OpenAI Responses API format.
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": code,
"message": message,
},
})
}
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return // Can't write error after stream started
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
setOpsRequestContext(c, modelName, stream, body) setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// Get subscription (may be nil) // Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)

View File

@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)

View File

@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) { if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.String("previous_response_id_kind", previousResponseIDKind), zap.String("previous_response_id_kind", previousResponseIDKind),
) )
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()

View File

@ -27,6 +27,9 @@ const (
opsRequestBodyKey = "ops_request_body" opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id" opsAccountIDKey = "ops_account_id"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
opsErrContextCanceled = "context canceled" opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts" opsErrNoAvailableAccounts = "no available accounts"
@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
} }
} }
// setOpsEndpointContext stores upstream model and request type for ops error logging.
// Called by handlers after model mapping and request type determination.
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
if c == nil {
return
}
if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" {
c.Set(opsUpstreamModelKey, upstreamModel)
}
c.Set(opsRequestTypeKey, requestType)
}
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil { if c == nil || entry == nil {
return return
@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
return "" return ""
}(), }(),
Stream: stream, Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"), UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: "upstream", ErrorPhase: "upstream",
@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
return "" return ""
}(), }(),
Stream: stream, Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"), UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase, ErrorPhase: phase,

View File

@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) {
}) })
} }
} }
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream
v, ok := c.Get(opsUpstreamModelKey)
require.True(t, ok)
vStr, ok := v.(string)
require.True(t, ok)
require.Equal(t, "claude-3-5-sonnet-20241022", vStr)
rt, ok := c.Get(opsRequestTypeKey)
require.True(t, ok)
rtVal, ok := rt.(int16)
require.True(t, ok)
require.Equal(t, int16(2), rtVal)
}
func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
setOpsEndpointContext(c, "", int16(1))
_, ok := c.Get(opsUpstreamModelKey)
require.False(t, ok, "empty upstream model should not be stored")
rt, ok := c.Get(opsRequestTypeKey)
require.True(t, ok)
rtVal, ok := rt.(int16)
require.True(t, ok)
require.Equal(t, int16(1), rtVal)
}
func TestSetOpsEndpointContext_NilContext(t *testing.T) {
require.NotPanics(t, func() {
setOpsEndpointContext(nil, "model", int16(1))
})
}

View File

@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,

View File

@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {

View File

@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
} }
setOpsRequestContext(c, reqModel, clientStream, body) setOpsRequestContext(c, reqModel, clientStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
platform := "" platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {

View File

@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {

View File

@ -0,0 +1,521 @@
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: AnthropicResponse → ResponsesResponse
// ---------------------------------------------------------------------------
// AnthropicToResponsesResponse converts an Anthropic Messages response into a
// Responses API response. This is the reverse of ResponsesToAnthropic and
// enables Anthropic upstream responses to be returned in OpenAI Responses format.
func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
id := resp.ID
if id == "" {
id = generateResponsesID()
}
out := &ResponsesResponse{
ID: id,
Object: "response",
Model: resp.Model,
}
var outputs []ResponsesOutput
var msgParts []ResponsesContentPart
for _, block := range resp.Content {
switch block.Type {
case "thinking":
if block.Thinking != "" {
outputs = append(outputs, ResponsesOutput{
Type: "reasoning",
ID: generateItemID(),
Summary: []ResponsesSummary{{
Type: "summary_text",
Text: block.Thinking,
}},
})
}
case "text":
if block.Text != "" {
msgParts = append(msgParts, ResponsesContentPart{
Type: "output_text",
Text: block.Text,
})
}
case "tool_use":
args := "{}"
if len(block.Input) > 0 {
args = string(block.Input)
}
outputs = append(outputs, ResponsesOutput{
Type: "function_call",
ID: generateItemID(),
CallID: toResponsesCallID(block.ID),
Name: block.Name,
Arguments: args,
Status: "completed",
})
}
}
// Assemble message output item from text parts
if len(msgParts) > 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: msgParts,
Status: "completed",
})
}
if len(outputs) == 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
Status: "completed",
})
}
out.Output = outputs
// Map stop_reason → status
out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content)
if out.Status == "incomplete" {
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
}
// Usage
out.Usage = &ResponsesUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.CacheReadInputTokens > 0 {
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: resp.Usage.CacheReadInputTokens,
}
}
return out
}
// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status.
func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string {
switch stopReason {
case "max_tokens":
return "incomplete"
case "end_turn", "tool_use", "stop_sequence":
return "completed"
default:
return "completed"
}
}
// ---------------------------------------------------------------------------
// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
// AnthropicEventToResponsesState tracks state for converting a sequence of
// Anthropic SSE events into Responses SSE events.
type AnthropicEventToResponsesState struct {
ResponseID string
Model string
Created int64
SequenceNumber int
// CreatedSent tracks whether response.created has been emitted.
CreatedSent bool
// CompletedSent tracks whether the terminal event has been emitted.
CompletedSent bool
// Current output tracking
OutputIndex int
CurrentItemID string
CurrentItemType string // "message" | "function_call" | "reasoning"
// For message output: accumulate text parts
ContentIndex int
// For function_call: track per-output info
CurrentCallID string
CurrentName string
// Usage from message_delta
InputTokens int
OutputTokens int
CacheReadInputTokens int
}
// NewAnthropicEventToResponsesState returns an initialised stream state.
func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState {
return &AnthropicEventToResponsesState{
Created: time.Now().Unix(),
}
}
// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into
// zero or more Responses SSE events, updating state as it goes.
func AnthropicEventToResponsesEvents(
evt *AnthropicStreamEvent,
state *AnthropicEventToResponsesState,
) []ResponsesStreamEvent {
switch evt.Type {
case "message_start":
return anthToResHandleMessageStart(evt, state)
case "content_block_start":
return anthToResHandleContentBlockStart(evt, state)
case "content_block_delta":
return anthToResHandleContentBlockDelta(evt, state)
case "content_block_stop":
return anthToResHandleContentBlockStop(evt, state)
case "message_delta":
return anthToResHandleMessageDelta(evt, state)
case "message_stop":
return anthToResHandleMessageStop(state)
default:
return nil
}
}
// FinalizeAnthropicResponsesStream emits synthetic termination events if the
// stream ended without a proper message_stop.
func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if !state.CreatedSent || state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
// Close any open item
events = append(events, closeCurrentResponsesItem(state)...)
// Emit response.completed
events = append(events, makeResponsesCompletedEvent(state, "completed", nil))
state.CompletedSent = true
return events
}
// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line.
func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) {
data, err := json.Marshal(evt)
if err != nil {
return "", err
}
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
}
// --- internal handlers ---
func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.Message != nil {
state.ResponseID = evt.Message.ID
if state.Model == "" {
state.Model = evt.Message.Model
}
if evt.Message.Usage.InputTokens > 0 {
state.InputTokens = evt.Message.Usage.InputTokens
}
}
if state.CreatedSent {
return nil
}
state.CreatedSent = true
// Emit response.created
return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)}
}
func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.ContentBlock == nil {
return nil
}
var events []ResponsesStreamEvent
switch evt.ContentBlock.Type {
case "thinking":
state.CurrentItemID = generateItemID()
state.CurrentItemType = "reasoning"
state.ContentIndex = 0
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "reasoning",
ID: state.CurrentItemID,
},
}))
case "text":
// If we don't have an open message item, open one
if state.CurrentItemType != "message" {
state.CurrentItemID = generateItemID()
state.CurrentItemType = "message"
state.ContentIndex = 0
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "message",
ID: state.CurrentItemID,
Role: "assistant",
Status: "in_progress",
},
}))
}
case "tool_use":
// Close previous item if any
events = append(events, closeCurrentResponsesItem(state)...)
state.CurrentItemID = generateItemID()
state.CurrentItemType = "function_call"
state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID)
state.CurrentName = evt.ContentBlock.Name
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "function_call",
ID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
Status: "in_progress",
},
}))
}
return events
}
func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.Delta == nil {
return nil
}
switch evt.Delta.Type {
case "text_delta":
if evt.Delta.Text == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ContentIndex: state.ContentIndex,
Delta: evt.Delta.Text,
ItemID: state.CurrentItemID,
})}
case "thinking_delta":
if evt.Delta.Thinking == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
SummaryIndex: 0,
Delta: evt.Delta.Thinking,
ItemID: state.CurrentItemID,
})}
case "input_json_delta":
if evt.Delta.PartialJSON == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Delta: evt.Delta.PartialJSON,
ItemID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
})}
case "signature_delta":
// Anthropic signature deltas have no Responses equivalent; skip
return nil
}
return nil
}
func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
switch state.CurrentItemType {
case "reasoning":
// Emit reasoning summary done + output item done
events := []ResponsesStreamEvent{
makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
SummaryIndex: 0,
ItemID: state.CurrentItemID,
}),
}
events = append(events, closeCurrentResponsesItem(state)...)
return events
case "function_call":
// Emit function_call_arguments.done + output item done
events := []ResponsesStreamEvent{
makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ItemID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
}),
}
events = append(events, closeCurrentResponsesItem(state)...)
return events
case "message":
// Emit output_text.done (text block is done, but message item stays open for potential more blocks)
return []ResponsesStreamEvent{
makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ContentIndex: state.ContentIndex,
ItemID: state.CurrentItemID,
}),
}
}
return nil
}
func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
// Update usage
if evt.Usage != nil {
state.OutputTokens = evt.Usage.OutputTokens
if evt.Usage.CacheReadInputTokens > 0 {
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
}
}
return nil
}
func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
// Close any open item
events = append(events, closeCurrentResponsesItem(state)...)
// Determine status
status := "completed"
var incompleteDetails *ResponsesIncompleteDetails
// Emit response.completed
events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails))
state.CompletedSent = true
return events
}
// --- helper functions ---
func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if state.CurrentItemType == "" {
return nil
}
itemType := state.CurrentItemType
itemID := state.CurrentItemID
// Reset
state.CurrentItemType = ""
state.CurrentItemID = ""
state.CurrentCallID = ""
state.CurrentName = ""
state.OutputIndex++
state.ContentIndex = 0
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex - 1, // Use the index before increment
Item: &ResponsesOutput{
Type: itemType,
ID: itemID,
Status: "completed",
},
})}
}
func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
return ResponsesStreamEvent{
Type: "response.created",
SequenceNumber: seq,
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: "in_progress",
Output: []ResponsesOutput{},
},
}
}
func makeResponsesCompletedEvent(
state *AnthropicEventToResponsesState,
status string,
incompleteDetails *ResponsesIncompleteDetails,
) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
usage := &ResponsesUsage{
InputTokens: state.InputTokens,
OutputTokens: state.OutputTokens,
TotalTokens: state.InputTokens + state.OutputTokens,
}
if state.CacheReadInputTokens > 0 {
usage.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: state.CacheReadInputTokens,
}
}
return ResponsesStreamEvent{
Type: "response.completed",
SequenceNumber: seq,
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: status,
Output: []ResponsesOutput{}, // Simplified; full output tracking would add complexity
Usage: usage,
IncompleteDetails: incompleteDetails,
},
}
}
func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
evt := *template
evt.Type = eventType
evt.SequenceNumber = seq
return evt
}
func generateResponsesID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "resp_" + hex.EncodeToString(b)
}
func generateItemID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "item_" + hex.EncodeToString(b)
}

View File

@ -0,0 +1,464 @@
package apicompat
import (
"encoding/json"
"fmt"
"strings"
)
// ResponsesToAnthropicRequest converts a Responses API request into an
// Anthropic Messages request. This is the reverse of AnthropicToResponses and
// enables Anthropic platform groups to accept OpenAI Responses API requests
// by converting them to the native /v1/messages format before forwarding upstream.
func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) {
system, messages, err := convertResponsesInputToAnthropic(req.Input)
if err != nil {
return nil, err
}
out := &AnthropicRequest{
Model: req.Model,
Messages: messages,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: req.Stream,
}
if len(system) > 0 {
out.System = system
}
// max_output_tokens → max_tokens
if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 {
out.MaxTokens = *req.MaxOutputTokens
}
if out.MaxTokens == 0 {
// Anthropic requires max_tokens; default to a sensible value.
out.MaxTokens = 8192
}
// Convert tools
if len(req.Tools) > 0 {
out.Tools = convertResponsesToAnthropicTools(req.Tools)
}
// Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses)
if len(req.ToolChoice) > 0 {
tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice)
if err != nil {
return nil, fmt.Errorf("convert tool_choice: %w", err)
}
out.ToolChoice = tc
}
// reasoning.effort → output_config.effort + thinking
if req.Reasoning != nil && req.Reasoning.Effort != "" {
effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort)
out.OutputConfig = &AnthropicOutputConfig{Effort: effort}
// Enable thinking for non-low efforts
if effort != "low" {
out.Thinking = &AnthropicThinking{
Type: "enabled",
BudgetTokens: defaultThinkingBudget(effort),
}
}
}
return out, nil
}
// defaultThinkingBudget returns a sensible thinking budget based on effort level.
func defaultThinkingBudget(effort string) int {
switch effort {
case "low":
return 1024
case "medium":
return 4096
case "high":
return 10240
case "max":
return 32768
default:
return 10240
}
}
// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to
// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses.
//
// low → low
// medium → medium
// high → high
// xhigh → max
func mapResponsesEffortToAnthropic(effort string) string {
if effort == "xhigh" {
return "max"
}
return effort // low→low, medium→medium, high→high, unknown→passthrough
}
// convertResponsesInputToAnthropic extracts system prompt and messages from
// a Responses API input array. Returns the system as raw JSON (for Anthropic's
// polymorphic system field) and a list of Anthropic messages.
func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) {
// Try as plain string input.
var inputStr string
if err := json.Unmarshal(inputRaw, &inputStr); err == nil {
content, _ := json.Marshal(inputStr)
return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil
}
var items []ResponsesInputItem
if err := json.Unmarshal(inputRaw, &items); err != nil {
return nil, nil, fmt.Errorf("parse responses input: %w", err)
}
var system json.RawMessage
var messages []AnthropicMessage
for _, item := range items {
switch {
case item.Role == "system":
// System prompt → Anthropic system field
text := extractTextFromContent(item.Content)
if text != "" {
system, _ = json.Marshal(text)
}
case item.Type == "function_call":
// function_call → assistant message with tool_use block
input := json.RawMessage("{}")
if item.Arguments != "" {
input = json.RawMessage(item.Arguments)
}
block := AnthropicContentBlock{
Type: "tool_use",
ID: fromResponsesCallIDToAnthropic(item.CallID),
Name: item.Name,
Input: input,
}
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
messages = append(messages, AnthropicMessage{
Role: "assistant",
Content: blockJSON,
})
case item.Type == "function_call_output":
// function_call_output → user message with tool_result block
outputContent := item.Output
if outputContent == "" {
outputContent = "(empty)"
}
contentJSON, _ := json.Marshal(outputContent)
block := AnthropicContentBlock{
Type: "tool_result",
ToolUseID: fromResponsesCallIDToAnthropic(item.CallID),
Content: contentJSON,
}
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
messages = append(messages, AnthropicMessage{
Role: "user",
Content: blockJSON,
})
case item.Role == "user":
content, err := convertResponsesUserToAnthropicContent(item.Content)
if err != nil {
return nil, nil, err
}
messages = append(messages, AnthropicMessage{
Role: "user",
Content: content,
})
case item.Role == "assistant":
content, err := convertResponsesAssistantToAnthropicContent(item.Content)
if err != nil {
return nil, nil, err
}
messages = append(messages, AnthropicMessage{
Role: "assistant",
Content: content,
})
default:
// Unknown role/type — attempt as user message
if item.Content != nil {
messages = append(messages, AnthropicMessage{
Role: "user",
Content: item.Content,
})
}
}
}
// Merge consecutive same-role messages (Anthropic requires alternating roles)
messages = mergeConsecutiveMessages(messages)
return system, messages, nil
}
// extractTextFromContent extracts text from a content field that may be a
// plain string or an array of content parts.
func extractTextFromContent(raw json.RawMessage) string {
if len(raw) == 0 {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s
}
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
var texts []string
for _, p := range parts {
if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" {
texts = append(texts, p.Text)
}
}
return strings.Join(texts, "\n\n")
}
return ""
}
// convertResponsesUserToAnthropicContent converts a Responses user message
// content field into Anthropic content blocks JSON.
func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.Marshal("") // empty string content
}
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Array of content parts → Anthropic content blocks.
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err != nil {
// Pass through as-is if we can't parse
return raw, nil
}
var blocks []AnthropicContentBlock
for _, p := range parts {
switch p.Type {
case "input_text", "text":
if p.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: p.Text,
})
}
case "input_image":
src := dataURIToAnthropicImageSource(p.ImageURL)
if src != nil {
blocks = append(blocks, AnthropicContentBlock{
Type: "image",
Source: src,
})
}
}
}
if len(blocks) == 0 {
return json.Marshal("")
}
return json.Marshal(blocks)
}
// convertResponsesAssistantToAnthropicContent converts a Responses assistant
// message content field into Anthropic content blocks JSON.
func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}})
}
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}})
}
// Array of content parts → Anthropic content blocks.
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err != nil {
return raw, nil
}
var blocks []AnthropicContentBlock
for _, p := range parts {
switch p.Type {
case "output_text", "text":
if p.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: p.Text,
})
}
}
}
if len(blocks) == 0 {
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
}
return json.Marshal(blocks)
}
// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to
// Anthropic format. Reverses toResponsesCallID.
func fromResponsesCallIDToAnthropic(id string) string {
// If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it
if after, ok := strings.CutPrefix(id, "fc_"); ok {
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
return after
}
}
// Generate a synthetic Anthropic tool ID
if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") {
return "toolu_" + id
}
return id
}
// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource.
func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource {
if !strings.HasPrefix(dataURI, "data:") {
return nil
}
// Format: data:<media_type>;base64,<data>
rest := strings.TrimPrefix(dataURI, "data:")
semicolonIdx := strings.Index(rest, ";")
if semicolonIdx < 0 {
return nil
}
mediaType := rest[:semicolonIdx]
rest = rest[semicolonIdx+1:]
if !strings.HasPrefix(rest, "base64,") {
return nil
}
data := strings.TrimPrefix(rest, "base64,")
return &AnthropicImageSource{
Type: "base64",
MediaType: mediaType,
Data: data,
}
}
// mergeConsecutiveMessages merges consecutive messages with the same role
// because Anthropic requires alternating user/assistant turns.
func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage {
if len(messages) <= 1 {
return messages
}
var merged []AnthropicMessage
for _, msg := range messages {
if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role {
merged = append(merged, msg)
continue
}
// Same role — merge content arrays
last := &merged[len(merged)-1]
lastBlocks := parseContentBlocks(last.Content)
newBlocks := parseContentBlocks(msg.Content)
combined := append(lastBlocks, newBlocks...)
last.Content, _ = json.Marshal(combined)
}
return merged
}
// parseContentBlocks attempts to parse content as []AnthropicContentBlock.
// If it's a string, wraps it in a text block.
func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock {
var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err == nil {
return blocks
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return []AnthropicContentBlock{{Type: "text", Text: s}}
}
return nil
}
// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format.
// Reverse of convertAnthropicToolsToResponses.
func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
case "web_search":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",
})
case "function":
out = append(out, AnthropicTool{
Name: t.Name,
Description: t.Description,
InputSchema: normalizeAnthropicInputSchema(t.Parameters),
})
default:
// Pass through unknown tool types
out = append(out, AnthropicTool{
Type: t.Type,
Name: t.Name,
Description: t.Description,
InputSchema: t.Parameters,
})
}
}
return out
}
// normalizeAnthropicInputSchema ensures the input_schema has a "type" field.
func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
if len(schema) == 0 || string(schema) == "null" {
return json.RawMessage(`{"type":"object","properties":{}}`)
}
return schema
}
// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format.
// Reverse of convertAnthropicToolChoiceToResponses.
//
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first
var s string
if err := json.Unmarshal(raw, &s); err == nil {
switch s {
case "auto":
return json.Marshal(map[string]string{"type": "auto"})
case "required":
return json.Marshal(map[string]string{"type": "any"})
case "none":
return json.Marshal(map[string]string{"type": "none"})
default:
return raw, nil
}
}
// Try as object with type=function
var tc struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
return json.Marshal(map[string]string{
"type": "tool",
"name": tc.Function.Name,
})
}
// Pass through unknown
return raw, nil
}

View File

@ -270,6 +270,7 @@ type OpenAIAuthClaims struct {
ChatGPTUserID string `json:"chatgpt_user_id"` ChatGPTUserID string `json:"chatgpt_user_id"`
ChatGPTPlanType string `json:"chatgpt_plan_type"` ChatGPTPlanType string `json:"chatgpt_plan_type"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
POID string `json:"poid"` // organization ID in access_token JWT
Organizations []OrganizationClaim `json:"organizations"` Organizations []OrganizationClaim `json:"organizations"`
} }

View File

@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return nil return nil
} }
func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
_, err := r.client.Account.UpdateOneID(id).
SetCredentials(normalizeJSONMap(credentials)).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id) groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil { if err != nil {
@ -443,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "", 0) return r.ListWithFilters(ctx, params, "", "", "", "", 0, "")
} }
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query() q := r.client.Account.Query()
if platform != "" { if platform != "" {
@ -479,6 +490,20 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
} else if groupID > 0 { } else if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
} }
if privacyMode != "" {
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
path := sqljson.Path("privacy_mode")
switch privacyMode {
case service.AccountPrivacyModeUnsetFilter:
s.Where(entsql.Or(
entsql.Not(sqljson.HasKey(dbaccount.FieldExtra, path)),
sqljson.ValueEQ(dbaccount.FieldExtra, "", path),
))
default:
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, privacyMode, path))
}
}))
}
total, err := q.Count(ctx) total, err := q.Count(ctx)
if err != nil { if err != nil {

View File

@ -208,15 +208,16 @@ func (s *AccountRepoSuite) TestList() {
func (s *AccountRepoSuite) TestListWithFilters() { func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct { tests := []struct {
name string name string
setup func(client *dbent.Client) setup func(client *dbent.Client)
platform string platform string
accType string accType string
status string status string
search string search string
groupID int64 groupID int64
wantCount int privacyMode string
validate func(accounts []service.Account) wantCount int
validate func(accounts []service.Account)
}{ }{
{ {
name: "filter_by_platform", name: "filter_by_platform",
@ -281,6 +282,32 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Empty(accounts[0].GroupIDs) s.Require().Empty(accounts[0].GroupIDs)
}, },
}, },
{
name: "filter_by_privacy_mode",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-ok", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-fail", Extra: map[string]any{"privacy_mode": service.PrivacyModeFailed}})
},
privacyMode: service.PrivacyModeTrainingOff,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("privacy-ok", accounts[0].Name)
},
},
{
name: "filter_by_privacy_mode_unset",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-unset", Extra: nil})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-empty", Extra: map[string]any{"privacy_mode": ""}})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-set", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
},
privacyMode: service.AccountPrivacyModeUnsetFilter,
wantCount: 2,
validate: func(accounts []service.Account) {
names := []string{accounts[0].Name, accounts[1].Name}
s.ElementsMatch([]string{"privacy-unset", "privacy-empty"}, names)
},
},
} }
for _, tt := range tests { for _, tt := range tests {
@ -293,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client) tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID) accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount) s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil { if tt.validate != nil {
@ -360,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID) s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)

View File

@ -29,6 +29,11 @@ INSERT INTO ops_error_logs (
model, model,
request_path, request_path,
stream, stream,
inbound_endpoint,
upstream_endpoint,
requested_model,
upstream_model,
request_type,
user_agent, user_agent,
error_phase, error_phase,
error_type, error_type,
@ -57,7 +62,7 @@ INSERT INTO ops_error_logs (
retry_count, retry_count,
created_at created_at
) VALUES ( ) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43
)` )`
func NewOpsRepository(db *sql.DB) service.OpsRepository { func NewOpsRepository(db *sql.DB) service.OpsRepository {
@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
opsNullString(input.Model), opsNullString(input.Model),
opsNullString(input.RequestPath), opsNullString(input.RequestPath),
input.Stream, input.Stream,
opsNullString(input.InboundEndpoint),
opsNullString(input.UpstreamEndpoint),
opsNullString(input.RequestedModel),
opsNullString(input.UpstreamModel),
opsNullInt16(input.RequestType),
opsNullString(input.UserAgent), opsNullString(input.UserAgent),
input.ErrorPhase, input.ErrorPhase,
input.ErrorType, input.ErrorType,
@ -231,7 +241,12 @@ SELECT
COALESCE(g.name, ''), COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''), COALESCE(e.request_path, ''),
e.stream e.stream,
COALESCE(e.inbound_endpoint, ''),
COALESCE(e.upstream_endpoint, ''),
COALESCE(e.requested_model, ''),
COALESCE(e.upstream_model, ''),
e.request_type
FROM ops_error_logs e FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id LEFT JOIN groups g ON e.group_id = g.id
@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
var resolvedBy sql.NullInt64 var resolvedBy sql.NullInt64
var resolvedByName string var resolvedByName string
var resolvedRetryID sql.NullInt64 var resolvedRetryID sql.NullInt64
var requestType sql.NullInt64
if err := rows.Scan( if err := rows.Scan(
&item.ID, &item.ID,
&item.CreatedAt, &item.CreatedAt,
@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
&clientIP, &clientIP,
&item.RequestPath, &item.RequestPath,
&item.Stream, &item.Stream,
&item.InboundEndpoint,
&item.UpstreamEndpoint,
&item.RequestedModel,
&item.UpstreamModel,
&requestType,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
item.GroupID = &v item.GroupID = &v
} }
item.GroupName = groupName item.GroupName = groupName
if requestType.Valid {
v := int16(requestType.Int64)
item.RequestType = &v
}
out = append(out, &item) out = append(out, &item)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@ -393,6 +418,11 @@ SELECT
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''), COALESCE(e.request_path, ''),
e.stream, e.stream,
COALESCE(e.inbound_endpoint, ''),
COALESCE(e.upstream_endpoint, ''),
COALESCE(e.requested_model, ''),
COALESCE(e.upstream_model, ''),
e.request_type,
COALESCE(e.user_agent, ''), COALESCE(e.user_agent, ''),
e.auth_latency_ms, e.auth_latency_ms,
e.routing_latency_ms, e.routing_latency_ms,
@ -427,6 +457,7 @@ LIMIT 1`
var responseLatency sql.NullInt64 var responseLatency sql.NullInt64
var ttft sql.NullInt64 var ttft sql.NullInt64
var requestBodyBytes sql.NullInt64 var requestBodyBytes sql.NullInt64
var requestType sql.NullInt64
err := r.db.QueryRowContext(ctx, q, id).Scan( err := r.db.QueryRowContext(ctx, q, id).Scan(
&out.ID, &out.ID,
@ -464,6 +495,11 @@ LIMIT 1`
&clientIP, &clientIP,
&out.RequestPath, &out.RequestPath,
&out.Stream, &out.Stream,
&out.InboundEndpoint,
&out.UpstreamEndpoint,
&out.RequestedModel,
&out.UpstreamModel,
&requestType,
&out.UserAgent, &out.UserAgent,
&authLatency, &authLatency,
&routingLatency, &routingLatency,
@ -540,6 +576,10 @@ LIMIT 1`
v := int(requestBodyBytes.Int64) v := int(requestBodyBytes.Int64)
out.RequestBodyBytes = &v out.RequestBodyBytes = &v
} }
if requestType.Valid {
v := int16(requestType.Int64)
out.RequestType = &v
}
// Normalize request_body to empty string when stored as JSON null. // Normalize request_body to empty string when stored as JSON null.
out.RequestBody = strings.TrimSpace(out.RequestBody) out.RequestBody = strings.TrimSpace(out.RequestBody)
@ -1479,3 +1519,10 @@ func opsNullInt(v any) any {
return sql.NullInt64{} return sql.NullInt64{}
} }
} }
func opsNullInt16(v *int16) any {
if v == nil {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: int64(*v), Valid: true}
}

View File

@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) {
"max_claude_code_version": "", "max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false, "backend_mode_enabled": false,
"custom_menu_items": [] "custom_menu_items": [],
"custom_endpoints": []
} }
}`, }`,
}, },
@ -989,7 +990,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }

View File

@ -69,12 +69,30 @@ func RegisterGatewayRoutes(
}) })
gateway.GET("/models", h.Gateway.Models) gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API: auto-route based on group platform
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", func(c *gin.Context) {
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
})
gateway.POST("/responses/*subpath", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
})
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API // OpenAI Chat Completions API: auto-route based on group platform
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) gateway.POST("/chat/completions", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.ChatCompletions(c)
return
}
h.Gateway.ChatCompletions(c)
})
} }
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
@ -92,12 +110,25 @@ func RegisterGatewayRoutes(
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
} }
// OpenAI Responses API不带v1前缀的别名 // OpenAI Responses API不带v1前缀的别名— auto-route based on group platform
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) responsesHandler := func(c *gin.Context) {
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
}
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API不带v1前缀的别名 // OpenAI Chat Completions API不带v1前缀的别名— auto-route based on group platform
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.ChatCompletions(c)
return
}
h.Gateway.ChatCompletions(c)
})
// Antigravity 模型列表 // Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)

View File

@ -0,0 +1,30 @@
package service
import "context"
type accountCredentialsUpdater interface {
UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error
}
func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error {
if repo == nil || account == nil {
return nil
}
account.Credentials = cloneCredentials(credentials)
if updater, ok := any(repo).(accountCredentialsUpdater); ok {
return updater.UpdateCredentials(ctx, account.ID, account.Credentials)
}
return repo.Update(ctx, account)
}
func cloneCredentials(in map[string]any) map[string]any {
if in == nil {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}

View File

@ -15,6 +15,7 @@ var (
) )
const AccountListGroupUngrouped int64 = -1 const AccountListGroupUngrouped int64 = -1
const AccountPrivacyModeUnsetFilter = "__unset__"
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *Account) error Create(ctx context.Context, account *Account) error
@ -37,7 +38,7 @@ type AccountRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)

View File

@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call") panic("unexpected List call")
} }
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }

View File

@ -54,7 +54,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@ -1451,9 +1451,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }

View File

@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct {
listWithFiltersType string listWithFiltersType string
listWithFiltersStatus string listWithFiltersStatus string
listWithFiltersSearch string listWithFiltersSearch string
listWithFiltersPrivacy string
listWithFiltersAccounts []Account listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error listWithFiltersErr error
} }
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++ s.listWithFiltersCalls++
s.listWithFiltersParams = params s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType s.listWithFiltersType = accountType
s.listWithFiltersStatus = status s.listWithFiltersStatus = status
s.listWithFiltersSearch = search s.listWithFiltersSearch = search
s.listWithFiltersPrivacy = privacyMode
if s.listWithFiltersErr != nil { if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr return nil, nil, s.listWithFiltersErr
@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}) })
} }
func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
t.Run("privacy_mode 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 2, Name: "acc2"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
require.Equal(t, PrivacyModeCFBlocked, repo.listWithFiltersPrivacy)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) { func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{ repo := &proxyRepoStubForAdminList{

View File

@ -643,6 +643,7 @@ urlFallbackLoop:
AccountID: p.account.ID, AccountID: p.account.ID,
AccountName: p.account.Name, AccountName: p.account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
}) })
@ -720,6 +721,7 @@ urlFallbackLoop:
AccountName: p.account.Name, AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: upstreamMsg, Message: upstreamMsg,
Detail: getUpstreamDetail(respBody), Detail: getUpstreamDetail(respBody),
@ -754,6 +756,7 @@ urlFallbackLoop:
AccountName: p.account.Name, AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: upstreamMsg, Message: upstreamMsg,
Detail: getUpstreamDetail(respBody), Detail: getUpstreamDetail(respBody),

View File

@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
p.markBackfillAttempted(account.ID) p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil {
slog.Warn("antigravity_project_id_backfill_persist_failed", slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID, "account_id", account.ID,
"error", updateErr, "error", updateErr,

View File

@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if targetType == AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
} }
item.Action = "created" item.Action = "created"
@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if targetType == AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
} }
@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
item.Action = "created" item.Action = "created"
result.Created++ result.Created++
@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
item.Action = "updated" item.Action = "updated"
@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
item.Action = "created" item.Action = "created"
result.Created++ result.Created++
@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
item.Action = "updated" item.Action = "updated"

View File

@ -119,6 +119,7 @@ const (
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表JSON 数组)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

View File

@ -12,6 +12,7 @@ import (
"net/smtp" "net/smtp"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[SettingKeySMTPHost] host := strings.TrimSpace(settings[SettingKeySMTPHost])
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return &SMTPConfig{ return &SMTPConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[SettingKeySMTPUsername], Username: strings.TrimSpace(settings[SettingKeySMTPUsername]),
Password: settings[SettingKeySMTPPassword], Password: strings.TrimSpace(settings[SettingKeySMTPPassword]),
From: settings[SettingKeySMTPFrom], From: strings.TrimSpace(settings[SettingKeySMTPFrom]),
FromName: settings[SettingKeySMTPFromName], FromName: strings.TrimSpace(settings[SettingKeySMTPFromName]),
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }

View File

@ -0,0 +1,485 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts an OpenAI Chat Completions API request body,
// converts it to Anthropic Messages format (chained via Responses format),
// forwards to the Anthropic upstream, and converts the response back to Chat
// Completions format. This enables Chat Completions clients to access Anthropic
// models through Anthropic platform groups.
func (s *GatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
parsed *ParsedRequest,
) (*ForwardResult, error) {
startTime := time.Now()
// 1. Parse Chat Completions request
var ccReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &ccReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := ccReq.Model
clientStream := ccReq.Stream
includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
// 2. Convert CC → Responses → Anthropic (chained conversion)
responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
if err != nil {
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
}
// 3. Force upstream streaming
anthropicReq.Stream = true
reqStream := true
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
}
}
anthropicReq.Model = mappedModel
logger.L().Debug("gateway forward_as_chat_completions: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("client_stream", clientStream),
)
// 5. Marshal Anthropic request body
anthropicBody, err := json.Marshal(anthropicReq)
if err != nil {
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts
isClaudeCode := false // CC API is never Claude Code
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
}
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 12. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
}
}
writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// 13. Extract reasoning effort from CC request body
reasoningEffort := extractCCReasoningEffortFromBody(body)
// 14. Handle normal response
// Read Anthropic SSE → convert to Responses events → convert to CC format
var result *ForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage)
} else {
result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
}
return result, handleErr
}
// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions
// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort)
// formats used by OpenAI-compatible clients.
func extractCCReasoningEffortFromBody(body []byte) *string {
raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if raw == "" {
raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
}
if raw == "" {
return nil
}
normalized := normalizeOpenAIReasoningEffort(raw)
if normalized == "" {
return nil
}
return &normalized
}
// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full
// response, then converts Anthropic → Responses → Chat Completions.
func (s *GatewayService) handleCCBufferedFromAnthropic(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResp *apicompat.AnthropicResponse
var usage ClaudeUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
// message_start carries the initial response structure and cache usage
if event.Type == "message_start" && event.Message != nil {
finalResp = event.Message
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// message_delta carries final usage and stop_reason
if event.Type == "message_delta" {
if event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
finalResp.StopReason = event.Delta.StopReason
}
}
if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
finalResp.Content = append(finalResp.Content, *event.ContentBlock)
}
if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
idx := *event.Index
if idx < len(finalResp.Content) {
switch event.Delta.Type {
case "text_delta":
finalResp.Content[idx].Text += event.Delta.Text
case "thinking_delta":
finalResp.Content[idx].Thinking += event.Delta.Thinking
case "input_json_delta":
finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_cc buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResp == nil {
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
return nil, fmt.Errorf("upstream stream ended without response")
}
// Update usage from accumulated delta
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
finalResp.Usage = apicompat.AnthropicUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
}
}
// Chain: Anthropic → Responses → Chat Completions
responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, ccResp)
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each
// to Responses events, then to Chat Completions chunks, and writes them.
func (s *GatewayService) handleCCStreamingFromAnthropic(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
includeUsage bool,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
// Use Anthropic→Responses state machine, then convert Responses→CC
anthState := apicompat.NewAnthropicEventToResponsesState()
anthState.Model = originalModel
ccState := apicompat.NewResponsesEventToChatState()
ccState.Model = originalModel
ccState.IncludeUsage = includeUsage
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *ForwardResult {
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
return false
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
return true // client disconnected
}
return false
}
processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// Extract usage from message_delta
if event.Type == "message_delta" && event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
// Also capture usage from message_start (carries cache fields)
if event.Type == "message_start" && event.Message != nil {
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// Chain: Anthropic event → Responses events → CC chunks
responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState)
for _, resEvt := range responsesEvents {
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range ccChunks {
if disconnected := writeChunk(chunk); disconnected {
return true
}
}
}
c.Writer.Flush()
return false
}
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
if processAnthropicEvent(&event) {
return resultWithUsage(), nil
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_cc stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Finalize both state machines
finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState)
for _, resEvt := range finalResEvents {
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range ccChunks {
writeChunk(chunk) //nolint:errcheck
}
}
finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState)
for _, chunk := range finalCCChunks {
writeChunk(chunk) //nolint:errcheck
}
// Write [DONE] marker
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
// writeGatewayCCError writes an error in OpenAI Chat Completions format for
// the Anthropic-upstream CC forwarding path.
func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@ -0,0 +1,109 @@
//go:build unit
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractCCReasoningEffortFromBody(t *testing.T) {
t.Parallel()
t.Run("nested reasoning.effort", func(t *testing.T) {
got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`))
require.NotNil(t, got)
require.Equal(t, "high", *got)
})
t.Run("flat reasoning_effort", func(t *testing.T) {
got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`))
require.NotNil(t, got)
require.Equal(t, "xhigh", *got)
})
t.Run("missing effort", func(t *testing.T) {
require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`)))
})
}
func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reasoningEffort := "high"
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens)
require.Equal(t, 7, result.Usage.OutputTokens)
require.Equal(t, 9, result.Usage.CacheReadInputTokens)
require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "high", *result.ReasoningEffort)
}
func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reasoningEffort := "medium"
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_cc_stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
``,
`event: message_stop`,
`data: {"type":"message_stop"}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 20, result.Usage.InputTokens)
require.Equal(t, 8, result.Usage.OutputTokens)
require.Equal(t, 11, result.Usage.CacheReadInputTokens)
require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "medium", *result.ReasoningEffort)
require.Contains(t, rec.Body.String(), `[DONE]`)
}

View File

@ -0,0 +1,518 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardAsResponses accepts an OpenAI Responses API request body, converts it
// to Anthropic Messages format, forwards to the Anthropic upstream, and converts
// the response back to Responses format. This enables OpenAI Responses API
// clients to access Anthropic models through Anthropic platform groups.
//
// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic
// but in reverse direction: Responses → Anthropic upstream → Responses.
func (s *GatewayService) ForwardAsResponses(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
parsed *ParsedRequest,
) (*ForwardResult, error) {
startTime := time.Now()
// 1. Parse Responses request
var responsesReq apicompat.ResponsesRequest
if err := json.Unmarshal(body, &responsesReq); err != nil {
return nil, fmt.Errorf("parse responses request: %w", err)
}
originalModel := responsesReq.Model
clientStream := responsesReq.Stream
// 2. Convert Responses → Anthropic
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq)
if err != nil {
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
}
// 3. Force upstream streaming (Anthropic works best with streaming)
anthropicReq.Stream = true
reqStream := true
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
}
}
anthropicReq.Model = mappedModel
logger.L().Debug("gateway forward_as_responses: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("client_stream", clientStream),
)
// 5. Marshal Anthropic request body
anthropicBody, err := json.Marshal(anthropicReq)
if err != nil {
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
isClaudeCode := false // Responses API is never Claude Code
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
}
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 12. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
}
}
// Non-failover error: return Responses-formatted error to client
writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// 13. Handle normal response (convert Anthropic → Responses)
var result *ForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
} else {
result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
}
return result, handleErr
}
// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort
// and normalizes it for usage logging.
func ExtractResponsesReasoningEffortFromBody(body []byte) *string {
raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if raw == "" {
return nil
}
normalized := normalizeOpenAIReasoningEffort(raw)
if normalized == "" {
return nil
}
return &normalized
}
func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) {
if dst == nil {
return
}
if src.InputTokens > 0 {
dst.InputTokens = src.InputTokens
}
if src.OutputTokens > 0 {
dst.OutputTokens = src.OutputTokens
}
if src.CacheReadInputTokens > 0 {
dst.CacheReadInputTokens = src.CacheReadInputTokens
}
if src.CacheCreationInputTokens > 0 {
dst.CacheCreationInputTokens = src.CacheCreationInputTokens
}
}
// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from
// the upstream streaming response, assembles them into a complete Anthropic
// response, converts to Responses API JSON format, and writes it to the client.
func (s *GatewayService) handleResponsesBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// Accumulate the final Anthropic response from streaming events
var finalResp *apicompat.AnthropicResponse
var usage ClaudeUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
eventType := strings.TrimPrefix(line, "event: ")
// Read the data line
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("forward_as_responses buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
zap.String("event_type", eventType),
)
continue
}
// message_start carries the initial response structure
if event.Type == "message_start" && event.Message != nil {
finalResp = event.Message
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// message_delta carries final usage and stop_reason
if event.Type == "message_delta" {
if event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
finalResp.StopReason = event.Delta.StopReason
}
}
// Accumulate content blocks
if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
finalResp.Content = append(finalResp.Content, *event.ContentBlock)
}
if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
idx := *event.Index
if idx < len(finalResp.Content) {
switch event.Delta.Type {
case "text_delta":
finalResp.Content[idx].Text += event.Delta.Text
case "thinking_delta":
finalResp.Content[idx].Thinking += event.Delta.Thinking
case "input_json_delta":
finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_responses buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResp == nil {
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
return nil, fmt.Errorf("upstream stream ended without response")
}
// Update usage from accumulated delta
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
finalResp.Usage = apicompat.AnthropicUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
}
}
// Convert to Responses format
responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
responsesResp.Model = originalModel // Use original model name
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, responsesResp)
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleResponsesStreamingResponse reads Anthropic SSE events from upstream,
// converts each to Responses SSE events, and writes them to the client.
func (s *GatewayService) handleResponsesStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewAnthropicEventToResponsesState()
state.Model = originalModel
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *ForwardResult {
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processEvent handles a single parsed Anthropic SSE event.
processEvent := func(event *apicompat.AnthropicStreamEvent) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// Extract usage from message_delta
if event.Type == "message_delta" && event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
// Also capture usage from message_start
if event.Type == "message_start" && event.Message != nil {
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// Convert to Responses events
events := apicompat.AnthropicEventToResponsesEvents(event, state)
for _, evt := range events {
sse, err := apicompat.ResponsesEventToSSE(evt)
if err != nil {
logger.L().Warn("forward_as_responses stream: failed to marshal event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
return true // client disconnected
}
}
if len(events) > 0 {
c.Writer.Flush()
}
return false
}
finalizeStream := func() (*ForwardResult, error) {
if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesEventToSSE(evt)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// Read Anthropic SSE events
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
eventType := strings.TrimPrefix(line, "event: ")
// Read data line
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("forward_as_responses stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
zap.String("event_type", eventType),
)
continue
}
if processEvent(&event) {
return resultWithUsage(), nil
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_responses stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
return finalizeStream()
}
// appendRawJSON appends a JSON fragment string to existing raw JSON.
func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage {
if len(existing) == 0 {
return json.RawMessage(fragment)
}
return json.RawMessage(string(existing) + fragment)
}
// writeResponsesError writes an error response in OpenAI Responses API format.
func writeResponsesError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"code": code,
"message": message,
},
})
}
// mapUpstreamStatusCode maps upstream HTTP status codes to appropriate client-facing codes.
func mapUpstreamStatusCode(code int) int {
if code >= 500 {
return http.StatusBadGateway
}
return code
}

View File

@ -0,0 +1,94 @@
//go:build unit
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractResponsesReasoningEffortFromBody(t *testing.T) {
t.Parallel()
got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`))
require.NotNil(t, got)
require.Equal(t, "high", *got)
require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`)))
}
func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_buffered"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens)
require.Equal(t, 7, result.Usage.OutputTokens)
require.Equal(t, 9, result.Usage.CacheReadInputTokens)
require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
require.Contains(t, rec.Body.String(), `"cached_tokens":9`)
}
func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
``,
`event: message_stop`,
`data: {"type":"message_stop"}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 20, result.Usage.InputTokens)
require.Equal(t, 8, result.Usage.OutputTokens)
require.Equal(t, 11, result.Usage.CacheReadInputTokens)
require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
require.Contains(t, rec.Body.String(), `response.completed`)
}

View File

@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {

View File

@ -5,6 +5,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
"regexp"
"sort"
"strings" "strings"
"unsafe" "unsafe"
@ -34,6 +36,9 @@ var (
patternEmptyTextSpaced = []byte(`"text": ""`) patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`) patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`) patternEmptyTextSp2 = []byte(`"text" :""`)
sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`)
sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`)
) )
// SessionContext 粘性会话上下文,用于区分不同来源的请求。 // SessionContext 粘性会话上下文,用于区分不同来源的请求。
@ -75,6 +80,49 @@ type ParsedRequest struct {
OnUpstreamAccepted func() OnUpstreamAccepted func()
} }
// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing.
// It preserves the set of product names from Product/Version tokens while
// discarding version-only changes and incidental comments.
func NormalizeSessionUserAgent(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1)
if len(matches) == 0 {
return normalizeSessionUserAgentFallback(raw)
}
products := make([]string, 0, len(matches))
seen := make(map[string]struct{}, len(matches))
for _, match := range matches {
if len(match) < 2 {
continue
}
product := strings.ToLower(strings.TrimSpace(match[1]))
if product == "" {
continue
}
if _, exists := seen[product]; exists {
continue
}
seen[product] = struct{}{}
products = append(products, product)
}
if len(products) == 0 {
return normalizeSessionUserAgentFallback(raw)
}
sort.Strings(products)
return strings.Join(products, "+")
}
func normalizeSessionUserAgentFallback(raw string) string {
normalized := strings.ToLower(strings.Join(strings.Fields(raw), " "))
normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "")
return strings.Join(strings.Fields(normalized), " ")
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果。 // ParseGatewayRequest 解析网关请求体并返回结构化结果。
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini // protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。 // 不同协议使用不同的 system/messages 字段名。
@ -205,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte {
return []byte(r.Raw) return []byte(r.Raw)
} }
// stripEmptyTextBlocksFromSlice removes empty text blocks from a content slice (including nested tool_result content).
// Returns (cleaned slice, true) if any blocks were removed, or (original, false) if unchanged.
func stripEmptyTextBlocksFromSlice(blocks []any) ([]any, bool) {
var result []any
changed := false
for i, block := range blocks {
blockMap, ok := block.(map[string]any)
if !ok {
if result != nil {
result = append(result, block)
}
continue
}
blockType, _ := blockMap["type"].(string)
// Strip empty text blocks
if blockType == "text" {
if txt, _ := blockMap["text"].(string); txt == "" {
if result == nil {
result = make([]any, 0, len(blocks))
result = append(result, blocks[:i]...)
}
changed = true
continue
}
}
// Recurse into tool_result nested content
if blockType == "tool_result" {
if nestedContent, ok := blockMap["content"].([]any); ok {
if cleaned, nestedChanged := stripEmptyTextBlocksFromSlice(nestedContent); nestedChanged {
if result == nil {
result = make([]any, 0, len(blocks))
result = append(result, blocks[:i]...)
}
changed = true
blockCopy := make(map[string]any, len(blockMap))
for k, v := range blockMap {
blockCopy[k] = v
}
blockCopy["content"] = cleaned
result = append(result, blockCopy)
continue
}
}
}
if result != nil {
result = append(result, block)
}
}
if !changed {
return blocks, false
}
return result, true
}
// StripEmptyTextBlocks removes empty text blocks from the request body (including nested tool_result content).
// This is a lightweight pre-filter for the initial request path to prevent upstream 400 errors.
// Returns the original body unchanged if no empty text blocks are found.
func StripEmptyTextBlocks(body []byte) []byte {
// Fast path: check if body contains empty text patterns
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
bytes.Contains(body, patternEmptyTextSpaced) ||
bytes.Contains(body, patternEmptyTextSp1) ||
bytes.Contains(body, patternEmptyTextSp2)
if !hasEmptyTextBlock {
return body
}
jsonStr := *(*string)(unsafe.Pointer(&body))
msgsRes := gjson.Get(jsonStr, "messages")
if !msgsRes.Exists() || !msgsRes.IsArray() {
return body
}
var messages []any
if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil {
return body
}
modified := false
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
if cleaned, changed := stripEmptyTextBlocksFromSlice(content); changed {
modified = true
msgMap["content"] = cleaned
}
}
if !modified {
return body
}
msgsBytes, err := json.Marshal(messages)
if err != nil {
return body
}
out, err := sjson.SetRawBytes(body, "messages", msgsBytes)
if err != nil {
return body
}
return out
}
// FilterThinkingBlocks removes thinking blocks from request body // FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe) // Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures // This prevents 400 errors from invalid thinking block signatures
@ -378,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
} }
} }
// Recursively strip empty text blocks from tool_result nested content.
if blockType == "tool_result" {
if nestedContent, ok := blockMap["content"].([]any); ok {
if cleaned, changed := stripEmptyTextBlocksFromSlice(nestedContent); changed {
modifiedThisMsg = true
ensureNewContent(bi)
blockCopy := make(map[string]any, len(blockMap))
for k, v := range blockMap {
blockCopy[k] = v
}
blockCopy["content"] = cleaned
newContent = append(newContent, blockCopy)
continue
}
}
}
if newContent != nil { if newContent != nil {
newContent = append(newContent, block) newContent = append(newContent, block)
} }

View File

@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
require.NotEmpty(t, block1["text"]) require.NotEmpty(t, block1["text"])
} }
func TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult(t *testing.T) {
// Empty text blocks nested inside tool_result content should also be stripped
input := []byte(`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":"valid result"},
{"type":"text","text":""}
]}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
toolResult := content0[0].(map[string]any)
require.Equal(t, "tool_result", toolResult["type"])
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 1)
require.Equal(t, "valid result", nestedContent[0].(map[string]any)["text"])
}
func TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice(t *testing.T) {
// If all nested content blocks in tool_result are empty text, content becomes empty slice
input := []byte(`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":""}
]},
{"type":"text","text":"hello"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 2)
toolResult := content0[0].(map[string]any)
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 0)
}
func TestStripEmptyTextBlocks(t *testing.T) {
t.Run("strips top-level empty text", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
require.Len(t, content, 1)
require.Equal(t, "hello", content[0].(map[string]any)["text"])
})
t.Run("strips nested empty text in tool_result", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
toolResult := content[0].(map[string]any)
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 1)
require.Equal(t, "ok", nestedContent[0].(map[string]any)["text"])
})
t.Run("no-op when no empty text", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := StripEmptyTextBlocks(input)
require.Equal(t, input, out)
})
t.Run("preserves non-map blocks in content", func(t *testing.T) {
// tool_result content can be a string; non-map blocks should pass through unchanged
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
require.Len(t, content, 1)
toolResult := content[0].(map[string]any)
require.Equal(t, "tool_result", toolResult["type"])
require.Equal(t, "string content", toolResult["content"])
})
t.Run("handles deeply nested tool_result", func(t *testing.T) {
// Recursive: tool_result containing another tool_result with empty text
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
outer := content[0].(map[string]any)
innerContent := outer["content"].([]any)
inner := innerContent[0].(map[string]any)
deepContent := inner["content"].([]any)
require.Len(t, deepContent, 1)
require.Equal(t, "deep", deepContent[0].(map[string]any)["text"])
})
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) { func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged // Non-empty text blocks should pass through unchanged
input := []byte(`{ input := []byte(`{

View File

@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed.SessionContext != nil { if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP) _, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":") _, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent) _, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent))
_, _ = combined.WriteString(":") _, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|") _, _ = combined.WriteString("|")
@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 调试日志:记录即将转发的账号信息 // 调试日志:记录即将转发的账号信息
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body) setOpsUpstreamRequestBody(c, body)
@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
}) })
@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "signature_error", Kind: "signature_error",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
Detail: func() string { Detail: func() string {
@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode, UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"), UpstreamRequestID: retryResp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(retryReq.URL.String()),
Kind: "signature_retry_thinking", Kind: "signature_retry_thinking",
Message: extractUpstreamErrorMessage(retryRespBody), Message: extractUpstreamErrorMessage(retryRespBody),
Detail: func() string { Detail: func() string {
@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(retryReq2.URL.String()),
Kind: "signature_retry_tools_request_error", Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
}) })
@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "budget_constraint_error", Kind: "budget_constraint_error",
Message: errMsg, Message: errMsg,
Detail: func() string { Detail: func() string {
@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
Detail: func() string { Detail: func() string {
@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
if c != nil { if c != nil {
c.Set("anthropic_passthrough", true) c.Set("anthropic_passthrough", true)
} }
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
input.Body = StripEmptyTextBlocks(input.Body)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, input.Body) setOpsUpstreamRequestBody(c, input.Body)
@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "retry", Kind: "retry",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
}) })
@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
Detail: func() string { Detail: func() string {
@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "request_error", Kind: "request_error",
Message: sanitizeUpstreamErrorMessage(err.Error()), Message: sanitizeUpstreamErrorMessage(err.Error()),
@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true, Passthrough: true,
Kind: "http_error", Kind: "http_error",
Message: upstreamMsg, Message: upstreamMsg,

View File

@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {

View File

@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
// 返回 16 字符的 Base64 编码的 SHA256 前缀 // 返回 16 字符的 Base64 编码的 SHA256 前缀
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string { func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
// 组合所有标识符 // 组合所有标识符
normalizedUserAgent := NormalizeSessionUserAgent(userAgent)
combined := strconv.FormatInt(userID, 10) + ":" + combined := strconv.FormatInt(userID, 10) + ":" +
strconv.FormatInt(apiKeyID, 10) + ":" + strconv.FormatInt(apiKeyID, 10) + ":" +
ip + ":" + ip + ":" +
userAgent + ":" + normalizedUserAgent + ":" +
platform + ":" + platform + ":" +
model model

View File

@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
} }
} }
func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro")
if hash1 != hash2 {
t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
}
}
func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro")
if hash1 != hash2 {
t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
}
}
func TestParseGeminiSessionValue(t *testing.T) { func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if tierID != "" { if tierID != "" {
account.Credentials["tier_id"] = tierID account.Credentials["tier_id"] = tierID
} }
_ = p.accountRepo.Update(ctx, account) _ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials)
} }
} }

View File

@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
} }
func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0"))
h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1"))
require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash")
}
func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0"))
h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1"))
require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash")
}
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}

View File

@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB // 5. 设置版本号 + 更新 DB
if newCredentials != nil { if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli() newCredentials["_token_version"] = time.Now().UnixMilli()
freshAccount.Credentials = newCredentials if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
slog.Error("oauth_refresh_update_failed", slog.Error("oauth_refresh_update_failed",
"account_id", freshAccount.ID, "account_id", freshAccount.ID,
"error", updateErr, "error", updateErr,

View File

@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. // refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type refreshAPIAccountRepo struct { type refreshAPIAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
account *Account // returned by GetByID account *Account // returned by GetByID
getByIDErr error getByIDErr error
updateErr error updateErr error
updateCalls int updateCalls int
updateCredentialsCalls int
} }
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return r.updateErr return r.updateErr
} }
func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
if r.account == nil || r.account.ID != id {
r.account = &Account{ID: id}
}
r.account.Credentials = cloneCredentials(credentials)
return nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. // refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type refreshAPIExecutorStub struct { type refreshAPIExecutorStub struct {
needsRefresh bool needsRefresh bool
@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require.Equal(t, "new-token", result.NewCredentials["access_token"]) require.Equal(t, "new-token", result.NewCredentials["access_token"])
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
require.Equal(t, 1, repo.updateCalls) // DB updated require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 1, cache.releaseCalls) // lock released require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, executor.refreshCalls) require.Equal(t, 1, executor.refreshCalls)
} }
func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) {
resetAt := time.Now().Add(45 * time.Minute)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "safe-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotNil(t, repo.account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second)
}
func TestRefreshIfNeeded_LockHeld(t *testing.T) { func TestRefreshIfNeeded_LockHeld(t *testing.T) {
account := &Account{ID: 2, Platform: PlatformAnthropic} account := &Account{ID: 2, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account} repo := &refreshAPIAccountRepo{account: account}
@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.Contains(t, err.Error(), "invalid_grant") require.Contains(t, err.Error(), "invalid_grant")
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
} }
@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result := MergeCredentials(old, new) result := MergeCredentials(old, new)
require.Equal(t, "new-token", result["access_token"]) // overridden require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
} }
// ========== BuildClaudeAccountCredentials tests ========== // ========== BuildClaudeAccountCredentials tests ==========

View File

@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil return nil, nil
} }
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired { if acquireErr == nil && result.Acquired {
@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue continue
} }
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil { if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr return nil, len(candidates), topK, loadSkew, acquireErr

View File

@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID) require.Equal(t, int64(32002), account.ID)
} }
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
ctx := context.Background()
groupID := int64(10103)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(33002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10104)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{stalePrimary, staleSecondary},
accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(34002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(9) groupID := int64(9)

View File

@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil return nil
} }
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号 // 刷新会话 TTL 并返回账号
// Refresh session TTL and return account // Refresh session TTL and return account
@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil { if fresh == nil {
continue continue
} }
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
if fresh == nil {
continue
}
// 选择优先级最高且最久未使用的账号 // 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used // Select highest priority and least recently used
@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
} }
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) { (requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if err == nil && result.Acquired { if account == nil {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return &AccountSelectionResult{ } else {
Account: account, result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
Acquired: true, if err == nil && result.Acquired {
ReleaseFunc: result.ReleaseFunc, _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
}, nil return &AccountSelectionResult{
} Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
AccountID: accountID, AccountID: accountID,
MaxConcurrency: account.Concurrency, MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout, Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting, MaxWaiting: cfg.StickySessionMaxWaiting,
}, },
}, nil }, nil
}
} }
} }
} }
@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return fresh return fresh
} }
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
return account
}
latest, err := s.accountRepo.GetByID(ctx, account.ID)
if err != nil || latest == nil {
return nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())
if !latest.IsSchedulable() || !latest.IsOpenAI() {
return nil
}
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
return nil
}
return latest
}
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
var ( var (
account *Account account *Account
@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
} }
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
// Passthrough mode preserves the raw upstream error response, but runtime
// account state still needs to be updated so sticky routing can stop
// reusing a freshly rate-limited account.
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,

View File

@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
require.True(t, arr[len(arr)-1].Passthrough) require.True(t, arr[len(arr)-1].Passthrough)
} }
func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-rate-limit"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))),
}
upstream := &httpUpstreamRecorder{resp: resp}
repo := &openAIWSRateLimitSignalRepo{}
rateSvc := &RateLimitService{accountRepo: repo}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
rateLimitService: rateSvc,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": true},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
_, err := svc.Forward(context.Background(), c, account, originalBody)
require.Error(t, err)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Contains(t, rec.Body.String(), "usage_limit_reached")
require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
}
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -29,9 +29,10 @@ type soraSessionChunk struct {
// OpenAIOAuthService handles OpenAI OAuth authentication flows // OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct { type OpenAIOAuthService struct {
sessionStore *openai.SessionStore sessionStore *openai.SessionStore
proxyRepo ProxyRepository proxyRepo ProxyRepository
oauthClient OpenAIOAuthClient oauthClient OpenAIOAuthClient
privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-apiImpersonateChrome
} }
// NewOpenAIOAuthService creates a new OpenAI OAuth service // NewOpenAIOAuthService creates a new OpenAI OAuth service
@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli
} }
} }
// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂,
// 用于调用 chatgpt.com/backend-api 获取账号信息plan_type 等)。
func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) {
s.privacyClientFactory = factory
}
// OpenAIAuthURLResult contains the authorization URL and session info // OpenAIAuthURLResult contains the authorization URL and session info
type OpenAIAuthURLResult struct { type OpenAIAuthURLResult struct {
AuthURL string `json:"auth_url"` AuthURL string `json:"auth_url"`
@ -131,6 +138,7 @@ type OpenAITokenInfo struct {
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
OrganizationID string `json:"organization_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"`
PlanType string `json:"plan_type,omitempty"` PlanType string `json:"plan_type,omitempty"`
PrivacyMode string `json:"privacy_mode,omitempty"`
} }
// ExchangeCode exchanges authorization code for tokens // ExchangeCode exchanges authorization code for tokens
@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo.PlanType = userInfo.PlanType tokenInfo.PlanType = userInfo.PlanType
} }
// id_token 中缺少 plan_type 时(如 Mobile RT尝试通过 ChatGPT backend-api 补全
if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
// 从 access_token JWT 中提取 orgIDpoid用于匹配正确的账号
orgID := tokenInfo.OrganizationID
if orgID == "" {
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
orgID = atClaims.OpenAIAuth.POID
}
}
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
if tokenInfo.PlanType == "" && info.PlanType != "" {
tokenInfo.PlanType = info.PlanType
}
if tokenInfo.Email == "" && info.Email != "" {
tokenInfo.Email = info.Email
}
}
}
// 尝试设置隐私关闭训练数据共享best-effort
if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
return tokenInfo, nil return tokenInfo, nil
} }

View File

@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
return PrivacyModeTrainingOff return PrivacyModeTrainingOff
} }
// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息
type ChatGPTAccountInfo struct {
PlanType string
Email string
}
const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.).
// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT).
// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team).
// Returns nil on any failure (best-effort, non-blocking).
func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, orgID string) *ChatGPTAccountInfo {
if accessToken == "" || clientFactory == nil {
return nil
}
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
client, err := clientFactory(proxyURL)
if err != nil {
slog.Debug("chatgpt_account_check_client_error", "error", err.Error())
return nil
}
var result map[string]any
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Origin", "https://chatgpt.com").
SetHeader("Referer", "https://chatgpt.com/").
SetHeader("Accept", "application/json").
SetSuccessResult(&result).
Get(chatGPTAccountsCheckURL)
if err != nil {
slog.Debug("chatgpt_account_check_request_error", "error", err.Error())
return nil
}
if !resp.IsSuccessState() {
slog.Debug("chatgpt_account_check_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200))
return nil
}
info := &ChatGPTAccountInfo{}
accounts, ok := result["accounts"].(map[string]any)
if !ok {
slog.Debug("chatgpt_account_check_no_accounts", "body", truncate(resp.String(), 300))
return nil
}
// 优先匹配 orgID 对应的账号access_token JWT 中的 poid
if orgID != "" {
if matched := extractPlanFromAccount(accounts, orgID); matched != "" {
info.PlanType = matched
}
}
// 未匹配到时,遍历所有账号:优先 is_default次选非 free
if info.PlanType == "" {
var defaultPlan, paidPlan, anyPlan string
for _, acctRaw := range accounts {
acct, ok := acctRaw.(map[string]any)
if !ok {
continue
}
planType := extractPlanType(acct)
if planType == "" {
continue
}
if anyPlan == "" {
anyPlan = planType
}
if account, ok := acct["account"].(map[string]any); ok {
if isDefault, _ := account["is_default"].(bool); isDefault {
defaultPlan = planType
}
}
if !strings.EqualFold(planType, "free") && paidPlan == "" {
paidPlan = planType
}
}
// 优先级default > 非 free > 任意
switch {
case defaultPlan != "":
info.PlanType = defaultPlan
case paidPlan != "":
info.PlanType = paidPlan
default:
info.PlanType = anyPlan
}
}
if info.PlanType == "" {
slog.Debug("chatgpt_account_check_no_plan_type", "body", truncate(resp.String(), 300))
return nil
}
slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID)
return info
}
// extractPlanFromAccount 从 accounts map 中按 keyaccount_id精确匹配并提取 plan_type
func extractPlanFromAccount(accounts map[string]any, accountKey string) string {
acctRaw, ok := accounts[accountKey]
if !ok {
return ""
}
acct, ok := acctRaw.(map[string]any)
if !ok {
return ""
}
return extractPlanType(acct)
}
// extractPlanType 从单个 account 对象中提取 plan_type
func extractPlanType(acct map[string]any) string {
if account, ok := acct["account"].(map[string]any); ok {
if planType, ok := account["plan_type"].(string); ok && planType != "" {
return planType
}
}
if entitlement, ok := acct["entitlement"].(map[string]any); ok {
if subPlan, ok := entitlement["subscription_plan"].(string); ok && subPlan != "" {
return subPlan
}
}
return ""
}
func truncate(s string, n int) string { func truncate(s string, n int) string {
if len(s) <= n { if len(s) <= n {
return s return s

View File

@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.Zero(t, boundAccountID) require.Zero(t, boundAccountID)
} }
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) {
ctx := context.Background()
groupID := int64(24)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleAccount := &Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
dbAccount := Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
RateLimitResetAt: &rateLimitedUntil,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
store := NewOpenAIWSStateStore(cache)
cfg := newOpenAIWSV2TestConfig()
snapshotCache := &openAISnapshotCacheStub{
accountsByID: map[int64]*Account{dbAccount.ID: staleAccount},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
openaiWSStateStore: store,
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
}
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
require.NoError(t, getErr)
require.Zero(t, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(23) groupID := int64(23)

View File

@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil return nil, nil
} }
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired { if acquireErr == nil && result.Acquired {

View File

@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re
return nil return nil
} }
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
_ = platform _ = platform
_ = accountType _ = accountType
_ = status _ = status
_ = search _ = search
_ = groupID _ = groupID
_ = privacyMode
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
} }
@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Len(t, accounts, 1) require.Len(t, accounts, 1)

View File

@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page, Page: page,
PageSize: opsAccountsPageSize, PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "", 0) }, platformFilter, "", "", "", 0, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -62,6 +62,12 @@ type OpsErrorLog struct {
ClientIP *string `json:"client_ip"` ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"` RequestPath string `json:"request_path"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
InboundEndpoint string `json:"inbound_endpoint"`
UpstreamEndpoint string `json:"upstream_endpoint"`
RequestedModel string `json:"requested_model"`
UpstreamModel string `json:"upstream_model"`
RequestType *int16 `json:"request_type"`
} }
type OpsErrorLogDetail struct { type OpsErrorLogDetail struct {

View File

@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct {
Model string Model string
RequestPath string RequestPath string
Stream bool Stream bool
// InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint string
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint string
// RequestedModel is the client-requested model name before mapping.
RequestedModel string
// UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping.
UpstreamModel string
// RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2.
// Matches service.RequestType enum semantics from usage_log.go.
RequestType *int16
UserAgent string UserAgent string
ErrorPhase string ErrorPhase string

View File

@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct {
UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped).
// Helps debug 404/routing errors by showing which endpoint was targeted.
UpstreamURL string `json:"upstream_url,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed). // Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt. // Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"` UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody) ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody) ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind) ev.Kind = strings.TrimSpace(ev.Kind)
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
ev.Message = strings.TrimSpace(ev.Message) ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail) ev.Detail = strings.TrimSpace(ev.Detail)
if ev.Message != "" { if ev.Message != "" {
@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
} }
return out, nil return out, nil
} }
// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment
// to avoid leaking sensitive query parameters (e.g. OAuth tokens).
func safeUpstreamURL(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
if idx := strings.IndexByte(rawURL, '?'); idx >= 0 {
rawURL = rawURL[:idx]
}
if idx := strings.IndexByte(rawURL, '#'); idx >= 0 {
rawURL = rawURL[:idx]
}
return rawURL
}

View File

@ -8,6 +8,27 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestSafeUpstreamURL(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"},
{"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"},
{"strips both", "https://host/path?token=secret#x", "https://host/path"},
{"no query or fragment", "https://host/path", "https://host/path"},
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"query before fragment", "https://h/p?a=1#f", "https://h/p"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, safeUpstreamURL(tt.input))
})
}
}
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()

View File

@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil { if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else { } else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)

View File

@ -15,9 +15,11 @@ import (
type rateLimitAccountRepoStub struct { type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini mockAccountRepoForGemini
setErrorCalls int setErrorCalls int
tempCalls int tempCalls int
lastErrorMsg string updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
} }
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return nil return nil
} }
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCredentialsCalls++
r.lastCredentials = cloneCredentials(credentials)
return nil
}
type tokenCacheInvalidatorRecorder struct { type tokenCacheInvalidatorRecorder struct {
accounts []*Account accounts []*Account
err error err error
@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
} }
@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts) require.Empty(t, invalidator.accounts)
} }
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}

View File

@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic(
func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) { func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) {

View File

@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyPurchaseSubscriptionURL, SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled, SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems, SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled, SettingKeyBackendModeEnabled,
} }
@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil }, nil
@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version, Version: s.version,
@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result return result
} }
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
if raw == "" {
return json.RawMessage("[]")
}
if json.Valid([]byte(raw)) {
return json.RawMessage(raw)
}
return json.RawMessage("[]")
}
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url // GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. // and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyPurchaseSubscriptionURL: "", SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false", SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]", SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]", SettingKeyDefaultSubscriptions: "[]",
@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
} }

View File

@ -43,6 +43,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
SoraClientEnabled bool SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
@ -104,6 +105,7 @@ type PublicSettings struct {
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
SoraClientEnabled bool SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool BackendModeEnabled bool

View File

@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
} }
if c.accountRepo != nil { if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
} }
} }

View File

@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials, err = refresher.Refresh(ctx, account) newCredentials, err = refresher.Refresh(ctx, account)
if newCredentials != nil { if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli() newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr) return fmt.Errorf("failed to save credentials: %w", saveErr)
} }
} }

View File

@ -14,19 +14,40 @@ import (
type tokenRefreshAccountRepo struct { type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
updateCalls int updateCalls int
setErrorCalls int fullUpdateCalls int
clearTempCalls int updateCredentialsCalls int
lastAccount *Account setErrorCalls int
updateErr error clearTempCalls int
lastAccount *Account
updateErr error
} }
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++ r.updateCalls++
r.fullUpdateCalls++
r.lastAccount = account r.lastAccount = account
return r.updateErr return r.updateErr
} }
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
cloned := cloneCredentials(credentials)
if r.accountsByID != nil {
if acc, ok := r.accountsByID[id]; ok && acc != nil {
acc.Credentials = cloned
r.lastAccount = acc
return nil
}
}
r.lastAccount = &Account{ID: id, Credentials: cloned}
return nil
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++ r.setErrorCalls++
return nil return nil
@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.Equal(t, 1, invalidator.calls) require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token")) require.Equal(t, "new-token", account.GetCredential("access_token"))
} }
@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
} }
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
Credentials: map[string]any{
"access_token": "old-token",
},
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.NotNil(t, account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除 require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
} }

View File

@ -0,0 +1,28 @@
-- Ops error logs: add endpoint, model mapping, and request_type fields
-- to match usage_logs observability coverage.
--
-- All columns are nullable with no default to preserve backward compatibility
-- with existing rows.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256),
ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256);
-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100),
ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);
-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS request_type SMALLINT;
COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.';
COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.';
COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).';
COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.';
COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. Matches usage_logs.request_type semantics.';

View File

@ -36,6 +36,7 @@ export async function list(
status?: string status?: string
group?: string group?: string
search?: string search?: string
privacy_mode?: string
lite?: string lite?: string
}, },
options?: { options?: {
@ -68,6 +69,7 @@ export async function listWithEtag(
status?: string status?: string
group?: string group?: string
search?: string search?: string
privacy_mode?: string
lite?: string lite?: string
}, },
options?: { options?: {
@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
export async function refreshOpenAIToken( export async function refreshOpenAIToken(
refreshToken: string, refreshToken: string,
proxyId?: number | null, proxyId?: number | null,
endpoint: string = '/admin/openai/refresh-token' endpoint: string = '/admin/openai/refresh-token',
clientId?: string
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
const payload: { refresh_token: string; proxy_id?: number } = { const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = {
refresh_token: refreshToken refresh_token: refreshToken
} }
if (proxyId) { if (proxyId) {
payload.proxy_id = proxyId payload.proxy_id = proxyId
} }
if (clientId) {
payload.client_id = clientId
}
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload) const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data return data
} }

View File

@ -969,6 +969,13 @@ export interface OpsErrorLog {
client_ip?: string | null client_ip?: string | null
request_path?: string request_path?: string
stream?: boolean stream?: boolean
// Error observability context (endpoint + model mapping)
inbound_endpoint?: string
upstream_endpoint?: string
requested_model?: string
upstream_model?: string
request_type?: number | null
} }
export interface OpsErrorDetail extends OpsErrorLog { export interface OpsErrorDetail extends OpsErrorLog {

View File

@ -4,7 +4,7 @@
*/ */
import { apiClient } from '../client' import { apiClient } from '../client'
import type { CustomMenuItem } from '@/types' import type { CustomMenuItem, CustomEndpoint } from '@/types'
export interface DefaultSubscriptionSetting { export interface DefaultSubscriptionSetting {
group_id: number group_id: number
@ -43,6 +43,7 @@ export interface SystemSettings {
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
// SMTP settings // SMTP settings
smtp_host: string smtp_host: string
smtp_port: number smtp_port: number
@ -112,6 +113,7 @@ export interface UpdateSettingsRequest {
sora_client_enabled?: boolean sora_client_enabled?: boolean
backend_mode_enabled?: boolean backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[] custom_menu_items?: CustomMenuItem[]
custom_endpoints?: CustomEndpoint[]
smtp_host?: string smtp_host?: string
smtp_port?: number smtp_port?: number
smtp_username?: string smtp_username?: string

View File

@ -661,6 +661,43 @@
</div> </div>
</div> </div>
<!-- OpenAI OAuth WS mode -->
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-ws-mode-label"
class="input-label mb-0"
for="bulk-edit-openai-ws-mode-enabled"
>
{{ t('admin.accounts.openai.wsMode') }}
</label>
<input
v-model="enableOpenAIWSMode"
id="bulk-edit-openai-ws-mode-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-ws-mode"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-ws-mode"
:class="!enableOpenAIWSMode && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t(openAIWSModeConcurrencyHintKey) }}
</p>
<Select
v-model="openaiOAuthResponsesWebSocketV2Mode"
data-testid="bulk-edit-openai-ws-mode-select"
:options="openAIWSModeOptions"
aria-labelledby="bulk-edit-openai-ws-mode-label"
/>
</div>
</div>
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) --> <!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600"> <div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between"> <div class="mb-3 flex items-center justify-between">
@ -883,6 +920,13 @@ import {
buildModelMappingObject as buildModelMappingPayload, buildModelMappingObject as buildModelMappingPayload,
getPresetMappingsByPlatform getPresetMappingsByPlatform
} from '@/composables/useModelWhitelist' } from '@/composables/useModelWhitelist'
import {
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
resolveOpenAIWSModeConcurrencyHintKey
} from '@/utils/openaiWsMode'
import type { OpenAIWSMode } from '@/utils/openaiWsMode'
interface Props { interface Props {
show: boolean show: boolean
accountIds: number[] accountIds: number[]
@ -913,6 +957,15 @@ const allOpenAIPassthroughCapable = computed(() => {
) )
}) })
const allOpenAIOAuth = computed(() => {
return (
props.selectedPlatforms.length === 1 &&
props.selectedPlatforms[0] === 'openai' &&
props.selectedTypes.length > 0 &&
props.selectedTypes.every(t => t === 'oauth')
)
})
// Anthropic OAuth/SetupTokenRPM // Anthropic OAuth/SetupTokenRPM
const allAnthropicOAuthOrSetupToken = computed(() => { const allAnthropicOAuthOrSetupToken = computed(() => {
return ( return (
@ -957,6 +1010,7 @@ const enableRateMultiplier = ref(false)
const enableStatus = ref(false) const enableStatus = ref(false)
const enableGroups = ref(false) const enableGroups = ref(false)
const enableOpenAIPassthrough = ref(false) const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
const enableRpmLimit = ref(false) const enableRpmLimit = ref(false)
// State - field values // State - field values
@ -979,6 +1033,7 @@ const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active') const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([]) const groupIds = ref<number[]>([])
const openaiPassthroughEnabled = ref(false) const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const rpmLimitEnabled = ref(false) const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref<number | null>(null) const bulkBaseRpm = ref<number | null>(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
@ -1005,10 +1060,19 @@ const statusOptions = computed(() => [
{ value: 'active', label: t('common.active') }, { value: 'active', label: t('common.active') },
{ value: 'inactive', label: t('common.inactive') } { value: 'inactive', label: t('common.inactive') }
]) ])
const isOpenAIModelRestrictionDisabled = computed(() => const isOpenAIModelRestrictionDisabled = computed(
allOpenAIPassthroughCapable.value && () =>
enableOpenAIPassthrough.value && allOpenAIPassthroughCapable.value &&
openaiPassthroughEnabled.value enableOpenAIPassthrough.value &&
openaiPassthroughEnabled.value
)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
) )
// Model mapping helpers // Model mapping helpers
@ -1180,6 +1244,14 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.credentials = credentials updates.credentials = credentials
} }
if (enableOpenAIWSMode.value) {
const extra = ensureExtra()
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
openaiOAuthResponsesWebSocketV2Mode.value
)
}
// RPM limit settings ( extra ) // RPM limit settings ( extra )
if (enableRpmLimit.value) { if (enableRpmLimit.value) {
const extra = ensureExtra() const extra = ensureExtra()
@ -1269,6 +1341,7 @@ const handleSubmit = async () => {
enableRateMultiplier.value || enableRateMultiplier.value ||
enableStatus.value || enableStatus.value ||
enableGroups.value || enableGroups.value ||
enableOpenAIWSMode.value ||
enableRpmLimit.value || enableRpmLimit.value ||
userMsgQueueMode.value !== null userMsgQueueMode.value !== null
@ -1361,6 +1434,7 @@ watch(
enableStatus.value = false enableStatus.value = false
enableGroups.value = false enableGroups.value = false
enableOpenAIPassthrough.value = false enableOpenAIPassthrough.value = false
enableOpenAIWSMode.value = false
enableRpmLimit.value = false enableRpmLimit.value = false
// Reset all values // Reset all values
@ -1379,6 +1453,7 @@ watch(
rateMultiplier.value = 1 rateMultiplier.value = 1
status.value = 'active' status.value = 'active'
groupIds.value = [] groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
rpmLimitEnabled.value = false rpmLimitEnabled.value = false
bulkBaseRpm.value = null bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered' bulkRpmStrategy.value = 'tiered'

View File

@ -2504,6 +2504,7 @@
:allow-multiple="form.platform === 'anthropic'" :allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'" :show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'" :show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
:show-mobile-refresh-token-option="form.platform === 'openai'"
:show-session-token-option="form.platform === 'sora'" :show-session-token-option="form.platform === 'sora'"
:show-access-token-option="form.platform === 'sora'" :show-access-token-option="form.platform === 'sora'"
:platform="form.platform" :platform="form.platform"
@ -2511,6 +2512,7 @@
@generate-url="handleGenerateUrl" @generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth" @cookie-auth="handleCookieAuth"
@validate-refresh-token="handleValidateRefreshToken" @validate-refresh-token="handleValidateRefreshToken"
@validate-mobile-refresh-token="handleOpenAIValidateMobileRT"
@validate-session-token="handleValidateSessionToken" @validate-session-token="handleValidateSessionToken"
@import-access-token="handleImportAccessToken" @import-access-token="handleImportAccessToken"
/> />
@ -4360,11 +4362,14 @@ const handleOpenAIExchange = async (authCode: string) => {
} }
// OpenAI RT // OpenAI RT
const handleOpenAIValidateRT = async (refreshTokenInput: string) => { // OpenAI Mobile RT 使 client_id openai.SoraClientID
const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK'
// OpenAI/Sora RT
const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => {
const oauthClient = activeOpenAIOAuth.value const oauthClient = activeOpenAIOAuth.value
if (!refreshTokenInput.trim()) return if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line)
const refreshTokens = refreshTokenInput const refreshTokens = refreshTokenInput
.split('\n') .split('\n')
.map((rt) => rt.trim()) .map((rt) => rt.trim())
@ -4389,7 +4394,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
try { try {
const tokenInfo = await oauthClient.validateRefreshToken( const tokenInfo = await oauthClient.validateRefreshToken(
refreshTokens[i], refreshTokens[i],
form.proxy_id form.proxy_id,
clientId
) )
if (!tokenInfo) { if (!tokenInfo) {
failedCount++ failedCount++
@ -4399,6 +4405,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
} }
const credentials = oauthClient.buildCredentials(tokenInfo) const credentials = oauthClient.buildCredentials(tokenInfo)
if (clientId) {
credentials.client_id = clientId
}
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const extra = buildOpenAIExtra(oauthExtra) const extra = buildOpenAIExtra(oauthExtra)
@ -4410,8 +4419,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
} }
} }
// Generate account name with index for batch // Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
const accountName = refreshTokens.length > 1 ? `${baseName} #${i + 1}` : baseName
let openaiAccountId: string | number | undefined let openaiAccountId: string | number | undefined
@ -4494,6 +4504,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
} }
} }
// RTCodex CLI client_id
const handleOpenAIValidateRT = (rt: string) => handleOpenAIBatchRT(rt)
// Mobile RTSoraClientID
const handleOpenAIValidateMobileRT = (rt: string) => handleOpenAIBatchRT(rt, OPENAI_MOBILE_RT_CLIENT_ID)
// Sora ST // Sora ST
const handleSoraValidateST = async (sessionTokenInput: string) => { const handleSoraValidateST = async (sessionTokenInput: string) => {
const oauthClient = activeOpenAIOAuth.value const oauthClient = activeOpenAIOAuth.value

View File

@ -48,6 +48,17 @@
t(getOAuthKey('refreshTokenAuth')) t(getOAuthKey('refreshTokenAuth'))
}}</span> }}</span>
</label> </label>
<label v-if="showMobileRefreshTokenOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
type="radio"
value="mobile_refresh_token"
class="text-blue-600 focus:ring-blue-500"
/>
<span class="text-sm text-blue-900 dark:text-blue-200">{{
t('admin.accounts.oauth.openai.mobileRefreshTokenAuth', '手动输入 Mobile RT')
}}</span>
</label>
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2"> <label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
<input <input
v-model="inputMethod" v-model="inputMethod"
@ -73,8 +84,8 @@
</div> </div>
</div> </div>
<!-- Refresh Token Input (OpenAI / Antigravity) --> <!-- Refresh Token Input (OpenAI / Antigravity / Mobile RT) -->
<div v-if="inputMethod === 'refresh_token'" class="space-y-4"> <div v-if="inputMethod === 'refresh_token' || inputMethod === 'mobile_refresh_token'" class="space-y-4">
<div <div
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80" class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
> >
@ -759,6 +770,7 @@ interface Props {
methodLabel?: string methodLabel?: string
showCookieOption?: boolean // Whether to show cookie auto-auth option showCookieOption?: boolean // Whether to show cookie auto-auth option
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only) showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
showMobileRefreshTokenOption?: boolean // Whether to show mobile refresh token option (OpenAI only)
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only) showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
showAccessTokenOption?: boolean // Whether to show access token input option (Sora only) showAccessTokenOption?: boolean // Whether to show access token input option (Sora only)
platform?: AccountPlatform // Platform type for different UI/text platform?: AccountPlatform // Platform type for different UI/text
@ -776,6 +788,7 @@ const props = withDefaults(defineProps<Props>(), {
methodLabel: 'Authorization Method', methodLabel: 'Authorization Method',
showCookieOption: true, showCookieOption: true,
showRefreshTokenOption: false, showRefreshTokenOption: false,
showMobileRefreshTokenOption: false,
showSessionTokenOption: false, showSessionTokenOption: false,
showAccessTokenOption: false, showAccessTokenOption: false,
platform: 'anthropic', platform: 'anthropic',
@ -787,6 +800,7 @@ const emit = defineEmits<{
'exchange-code': [code: string] 'exchange-code': [code: string]
'cookie-auth': [sessionKey: string] 'cookie-auth': [sessionKey: string]
'validate-refresh-token': [refreshToken: string] 'validate-refresh-token': [refreshToken: string]
'validate-mobile-refresh-token': [refreshToken: string]
'validate-session-token': [sessionToken: string] 'validate-session-token': [sessionToken: string]
'import-access-token': [accessToken: string] 'import-access-token': [accessToken: string]
'update:inputMethod': [method: AuthInputMethod] 'update:inputMethod': [method: AuthInputMethod]
@ -834,7 +848,7 @@ const oauthState = ref('')
const projectId = ref('') const projectId = ref('')
// Computed: show method selection when either cookie or refresh token option is enabled // Computed: show method selection when either cookie or refresh token option is enabled
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
// Clipboard // Clipboard
const { copied, copyToClipboard } = useClipboard() const { copied, copyToClipboard } = useClipboard()
@ -945,7 +959,11 @@ const handleCookieAuth = () => {
const handleValidateRefreshToken = () => { const handleValidateRefreshToken = () => {
if (refreshTokenInput.value.trim()) { if (refreshTokenInput.value.trim()) {
emit('validate-refresh-token', refreshTokenInput.value.trim()) if (inputMethod.value === 'mobile_refresh_token') {
emit('validate-mobile-refresh-token', refreshTokenInput.value.trim())
} else {
emit('validate-refresh-token', refreshTokenInput.value.trim())
}
} }
} }

View File

@ -149,6 +149,35 @@ describe('BulkEditAccountModal', () => {
}) })
}) })
it('OpenAI OAuth 批量编辑应提交 OAuth 专属 WS mode 字段', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['oauth']
})
await wrapper.get('#bulk-edit-openai-ws-mode-enabled').setValue(true)
await wrapper.get('[data-testid="bulk-edit-openai-ws-mode-select"]').setValue('passthrough')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
extra: {
openai_oauth_responses_websockets_v2_mode: 'passthrough',
openai_oauth_responses_websockets_v2_enabled: true
}
})
})
it('OpenAI API Key 批量编辑不显示 WS mode 入口', () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['apikey']
})
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
})
it('OpenAI 账号批量编辑可关闭自动透传', async () => { it('OpenAI 账号批量编辑可关闭自动透传', async () => {
const wrapper = mountModal({ const wrapper = mountModal({
selectedPlatforms: ['openai'], selectedPlatforms: ['openai'],

View File

@ -10,6 +10,7 @@
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" /> <Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" /> <Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" /> <Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
<Select :model-value="filters.privacy_mode" class="w-40" :options="privacyOpts" @update:model-value="updatePrivacyMode" @change="$emit('change')" />
<Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" /> <Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" />
</div> </div>
</template> </template>
@ -22,10 +23,18 @@ const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); co
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) } const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) } const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) } const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
const updatePrivacyMode = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, privacy_mode: value }) }
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) } const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }]) const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }]) const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
const privacyOpts = computed(() => [
{ value: '', label: t('admin.accounts.allPrivacyModes') },
{ value: '__unset__', label: t('admin.accounts.privacyUnset') },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
const gOpts = computed(() => [ const gOpts = computed(() => [
{ value: '', label: t('admin.accounts.allGroups') }, { value: '', label: t('admin.accounts.allGroups') },
{ value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') }, { value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') },

View File

@ -0,0 +1,56 @@
import { describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import AccountTableFilters from '../AccountTableFilters.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
describe('AccountTableFilters', () => {
it('renders privacy mode options and emits privacy_mode updates', async () => {
const wrapper = mount(AccountTableFilters, {
props: {
searchQuery: '',
filters: {
platform: '',
type: '',
status: '',
group: '',
privacy_mode: ''
},
groups: []
},
global: {
stubs: {
SearchInput: {
template: '<div />'
},
Select: {
props: ['modelValue', 'options'],
emits: ['update:modelValue', 'change'],
template: '<div class="select-stub" :data-options="JSON.stringify(options)" />'
}
}
}
})
const selects = wrapper.findAll('.select-stub')
expect(selects).toHaveLength(5)
const privacyOptions = JSON.parse(selects[3].attributes('data-options'))
expect(privacyOptions).toEqual([
{ value: '', label: 'admin.accounts.allPrivacyModes' },
{ value: '__unset__', label: 'admin.accounts.privacyUnset' },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
})
})

View File

@ -0,0 +1,141 @@
<script setup lang="ts">
import { computed, onBeforeUnmount, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { useClipboard } from '@/composables/useClipboard'
import type { CustomEndpoint } from '@/types'
const props = defineProps<{
apiBaseUrl: string
customEndpoints: CustomEndpoint[]
}>()
const { t } = useI18n()
const { copyToClipboard } = useClipboard()
const copiedEndpoint = ref<string | null>(null)
let copiedResetTimer: number | undefined
const allEndpoints = computed(() => {
const items: Array<{ name: string; endpoint: string; description: string; isDefault: boolean }> = []
if (props.apiBaseUrl) {
items.push({
name: t('keys.endpoints.title'),
endpoint: props.apiBaseUrl,
description: '',
isDefault: true,
})
}
for (const ep of props.customEndpoints) {
items.push({ ...ep, isDefault: false })
}
return items
})
async function copy(url: string) {
const success = await copyToClipboard(url, t('keys.endpoints.copied'))
if (!success) return
copiedEndpoint.value = url
if (copiedResetTimer !== undefined) {
window.clearTimeout(copiedResetTimer)
}
copiedResetTimer = window.setTimeout(() => {
if (copiedEndpoint.value === url) {
copiedEndpoint.value = null
}
}, 1800)
}
function tooltipHint(endpoint: string): string {
return copiedEndpoint.value === endpoint
? t('keys.endpoints.copiedHint')
: t('keys.endpoints.clickToCopy')
}
function speedTestUrl(endpoint: string): string {
return `https://www.tcptest.cn/http/${encodeURIComponent(endpoint)}`
}
onBeforeUnmount(() => {
if (copiedResetTimer !== undefined) {
window.clearTimeout(copiedResetTimer)
}
})
</script>
<template>
<div v-if="allEndpoints.length > 0" class="flex flex-wrap gap-2">
<div
v-for="(item, index) in allEndpoints"
:key="index"
class="flex items-center gap-1.5 rounded-lg border border-gray-200 bg-white px-2.5 py-1.5 text-xs transition-colors hover:border-primary-200 dark:border-dark-600 dark:bg-dark-800 dark:hover:border-primary-700"
>
<span class="font-medium text-gray-600 dark:text-gray-300">{{ item.name }}</span>
<span
v-if="item.isDefault"
class="rounded bg-primary-50 px-1 py-px text-[10px] font-medium leading-tight text-primary-600 dark:bg-primary-900/30 dark:text-primary-400"
>{{ t('keys.endpoints.default') }}</span>
<span class="text-gray-300 dark:text-dark-500">|</span>
<div class="group/endpoint relative flex items-center gap-1.5">
<div
class="pointer-events-none absolute bottom-full left-1/2 z-20 mb-2 w-max max-w-[24rem] -translate-x-1/2 translate-y-1 rounded-xl border border-slate-200 bg-white px-3 py-2.5 text-left opacity-0 shadow-[0_14px_36px_-20px_rgba(15,23,42,0.35)] ring-1 ring-slate-200/80 transition-all duration-150 group-hover/endpoint:translate-y-0 group-hover/endpoint:opacity-100 group-focus-within/endpoint:translate-y-0 group-focus-within/endpoint:opacity-100 dark:border-slate-700 dark:bg-slate-900 dark:ring-slate-700/70"
>
<p
v-if="item.description"
class="max-w-[24rem] break-words text-xs leading-5 text-slate-600 dark:text-slate-200"
>
{{ item.description }}
</p>
<p
class="flex items-center gap-1.5 text-[11px] leading-4 text-primary-600 dark:text-primary-300"
:class="item.description ? 'mt-1.5' : ''"
>
<span class="h-1.5 w-1.5 rounded-full bg-primary-500 dark:bg-primary-300"></span>
{{ tooltipHint(item.endpoint) }}
</p>
<div class="absolute left-1/2 top-full h-3 w-3 -translate-x-1/2 -translate-y-1/2 rotate-45 border-b border-r border-slate-200 bg-white dark:border-slate-700 dark:bg-slate-900"></div>
</div>
<code
class="cursor-pointer font-mono text-gray-500 decoration-gray-400 decoration-dashed underline-offset-2 hover:text-primary-600 hover:underline focus:text-primary-600 focus:underline focus:outline-none dark:text-gray-400 dark:decoration-gray-500 dark:hover:text-primary-400 dark:focus:text-primary-400"
role="button"
tabindex="0"
@click="copy(item.endpoint)"
@keydown.enter.prevent="copy(item.endpoint)"
@keydown.space.prevent="copy(item.endpoint)"
>{{ item.endpoint }}</code>
<button
type="button"
class="rounded p-0.5 transition-colors"
:class="copiedEndpoint === item.endpoint
? 'text-emerald-500 dark:text-emerald-400'
: 'text-gray-400 hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400'"
:aria-label="tooltipHint(item.endpoint)"
@click="copy(item.endpoint)"
>
<svg v-if="copiedEndpoint === item.endpoint" class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2.2">
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
</svg>
<svg v-else class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z" />
</svg>
</button>
<a
:href="speedTestUrl(item.endpoint)"
target="_blank"
rel="noopener noreferrer"
class="rounded p-0.5 text-gray-400 transition-colors hover:text-amber-500 dark:text-gray-500 dark:hover:text-amber-400"
:title="t('keys.endpoints.speedTest')"
>
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M13 10V3L4 14h7v7l9-11h-7z" />
</svg>
</a>
</div>
</div>
</div>
</template>

View File

@ -0,0 +1,69 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
const copyToClipboard = vi.fn().mockResolvedValue(true)
const messages: Record<string, string> = {
'keys.endpoints.title': 'API 端点',
'keys.endpoints.default': '默认',
'keys.endpoints.copied': '已复制',
'keys.endpoints.copiedHint': '已复制到剪贴板',
'keys.endpoints.clickToCopy': '点击可复制此端点',
'keys.endpoints.speedTest': '测速',
}
vi.mock('vue-i18n', () => ({
useI18n: () => ({
t: (key: string) => messages[key] ?? key,
}),
}))
vi.mock('@/composables/useClipboard', () => ({
useClipboard: () => ({
copyToClipboard,
}),
}))
import EndpointPopover from '../EndpointPopover.vue'
describe('EndpointPopover', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('将说明提示渲染到 URL 上方而不是旧的 title 图标上', () => {
const wrapper = mount(EndpointPopover, {
props: {
apiBaseUrl: 'https://default.example.com/v1',
customEndpoints: [
{
name: '备用线路',
endpoint: 'https://backup.example.com/v1',
description: '自定义说明',
},
],
},
})
expect(wrapper.text()).toContain('自定义说明')
expect(wrapper.text()).toContain('点击可复制此端点')
expect(wrapper.find('[role="button"]').attributes('title')).toBeUndefined()
expect(wrapper.find('[title="自定义说明"]').exists()).toBe(false)
})
it('点击 URL 后会复制并切换为已复制提示', async () => {
const wrapper = mount(EndpointPopover, {
props: {
apiBaseUrl: 'https://default.example.com/v1',
customEndpoints: [],
},
})
await wrapper.find('[role="button"]').trigger('click')
await flushPromises()
expect(copyToClipboard).toHaveBeenCalledWith('https://default.example.com/v1', '已复制')
expect(wrapper.text()).toContain('已复制到剪贴板')
expect(wrapper.find('button[aria-label="已复制到剪贴板"]').exists()).toBe(true)
})
})

View File

@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
export type AddMethod = 'oauth' | 'setup-token' export type AddMethod = 'oauth' | 'setup-token'
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' | 'access_token' export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'mobile_refresh_token' | 'session_token' | 'access_token'
export interface OAuthState { export interface OAuthState {
authUrl: string authUrl: string

View File

@ -13,6 +13,8 @@ export interface OpenAITokenInfo {
scope?: string scope?: string
email?: string email?: string
name?: string name?: string
plan_type?: string
privacy_mode?: string
// OpenAI specific IDs (extracted from ID Token) // OpenAI specific IDs (extracted from ID Token)
chatgpt_account_id?: string chatgpt_account_id?: string
chatgpt_user_id?: string chatgpt_user_id?: string
@ -126,9 +128,11 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
} }
// Validate refresh token and get full token info // Validate refresh token and get full token info
// clientId: 指定 OAuth client_id用于第三方渠道获取的 RT如 app_LlGpXReQgckcGGUo2JrYvtJK
const validateRefreshToken = async ( const validateRefreshToken = async (
refreshToken: string, refreshToken: string,
proxyId?: number | null proxyId?: number | null,
clientId?: string
): Promise<OpenAITokenInfo | null> => { ): Promise<OpenAITokenInfo | null> => {
if (!refreshToken.trim()) { if (!refreshToken.trim()) {
error.value = 'Missing refresh token' error.value = 'Missing refresh token'
@ -143,11 +147,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken( const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(
refreshToken.trim(), refreshToken.trim(),
proxyId, proxyId,
`${endpointPrefix}/refresh-token` `${endpointPrefix}/refresh-token`,
clientId
) )
return tokenInfo as OpenAITokenInfo return tokenInfo as OpenAITokenInfo
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to validate refresh token' error.value = err.response?.data?.detail || err.message || 'Failed to validate refresh token'
appStore.showError(error.value) appStore.showError(error.value)
return null return null
} finally { } finally {
@ -182,22 +187,23 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
} }
} }
// Build credentials for OpenAI OAuth account // Build credentials for OpenAI OAuth account (aligned with backend BuildAccountCredentials)
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => { const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
const creds: Record<string, unknown> = { const creds: Record<string, unknown> = {
access_token: tokenInfo.access_token, access_token: tokenInfo.access_token,
refresh_token: tokenInfo.refresh_token, expires_at: tokenInfo.expires_at
token_type: tokenInfo.token_type,
expires_in: tokenInfo.expires_in,
expires_at: tokenInfo.expires_at,
scope: tokenInfo.scope
} }
if (tokenInfo.client_id) { // 仅在返回了新的 refresh_token 时才写入,防止用空值覆盖已有令牌
creds.client_id = tokenInfo.client_id if (tokenInfo.refresh_token) {
creds.refresh_token = tokenInfo.refresh_token
}
if (tokenInfo.id_token) {
creds.id_token = tokenInfo.id_token
}
if (tokenInfo.email) {
creds.email = tokenInfo.email
} }
// Include OpenAI specific IDs (required for forwarding)
if (tokenInfo.chatgpt_account_id) { if (tokenInfo.chatgpt_account_id) {
creds.chatgpt_account_id = tokenInfo.chatgpt_account_id creds.chatgpt_account_id = tokenInfo.chatgpt_account_id
} }
@ -207,6 +213,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
if (tokenInfo.organization_id) { if (tokenInfo.organization_id) {
creds.organization_id = tokenInfo.organization_id creds.organization_id = tokenInfo.organization_id
} }
if (tokenInfo.plan_type) {
creds.plan_type = tokenInfo.plan_type
}
if (tokenInfo.client_id) {
creds.client_id = tokenInfo.client_id
}
return creds return creds
} }
@ -220,6 +232,9 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
if (tokenInfo.name) { if (tokenInfo.name) {
extra.name = tokenInfo.name extra.name = tokenInfo.name
} }
if (tokenInfo.privacy_mode) {
extra.privacy_mode = tokenInfo.privacy_mode
}
return Object.keys(extra).length > 0 ? extra : undefined return Object.keys(extra).length > 0 ? extra : undefined
} }

View File

@ -533,6 +533,14 @@ export default {
title: 'API Keys', title: 'API Keys',
description: 'Manage your API keys and access tokens', description: 'Manage your API keys and access tokens',
searchPlaceholder: 'Search name or key...', searchPlaceholder: 'Search name or key...',
endpoints: {
title: 'API Endpoints',
default: 'Default',
copied: 'Copied',
copiedHint: 'Copied to clipboard',
clickToCopy: 'Click to copy this endpoint',
speedTest: 'Speed Test',
},
allGroups: 'All Groups', allGroups: 'All Groups',
allStatus: 'All Status', allStatus: 'All Status',
createKey: 'Create API Key', createKey: 'Create API Key',
@ -1971,6 +1979,8 @@ export default {
expiresAt: 'Expires At', expiresAt: 'Expires At',
actions: 'Actions' actions: 'Actions'
}, },
allPrivacyModes: 'All Privacy States',
privacyUnset: 'Unset',
privacyTrainingOff: 'Training data sharing disabled', privacyTrainingOff: 'Training data sharing disabled',
privacyCfBlocked: 'Blocked by Cloudflare, training may still be on', privacyCfBlocked: 'Blocked by Cloudflare, training may still be on',
privacyFailed: 'Failed to disable training', privacyFailed: 'Failed to disable training',
@ -3486,7 +3496,12 @@ export default {
typeRequest: 'Request', typeRequest: 'Request',
typeAuth: 'Auth', typeAuth: 'Auth',
typeRouting: 'Routing', typeRouting: 'Routing',
typeInternal: 'Internal' typeInternal: 'Internal',
endpoint: 'Endpoint',
requestType: 'Type',
requestTypeSync: 'Sync',
requestTypeStream: 'Stream',
requestTypeWs: 'WS'
}, },
// Error Details Modal // Error Details Modal
errorDetails: { errorDetails: {
@ -3572,6 +3587,16 @@ export default {
latency: 'Request Duration', latency: 'Request Duration',
businessLimited: 'Business Limited', businessLimited: 'Business Limited',
requestPath: 'Request Path', requestPath: 'Request Path',
inboundEndpoint: 'Inbound Endpoint',
upstreamEndpoint: 'Upstream Endpoint',
requestedModel: 'Requested Model',
upstreamModel: 'Upstream Model',
requestType: 'Request Type',
requestTypeUnknown: 'Unknown',
requestTypeSync: 'Sync',
requestTypeStream: 'Stream',
requestTypeWs: 'WebSocket',
modelMapping: 'Model Mapping',
timings: 'Timings', timings: 'Timings',
auth: 'Auth', auth: 'Auth',
routing: 'Routing', routing: 'Routing',
@ -4162,6 +4187,18 @@ export default {
apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlPlaceholder: 'https://api.example.com',
apiBaseUrlHint: apiBaseUrlHint:
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.', 'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
customEndpoints: {
title: 'Custom Endpoints',
description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page',
itemLabel: 'Endpoint #{n}',
name: 'Name',
namePlaceholder: 'e.g., OpenAI Compatible',
endpointUrl: 'Endpoint URL',
endpointUrlPlaceholder: 'https://api2.example.com',
descriptionLabel: 'Description',
descriptionPlaceholder: 'e.g., Supports OpenAI format requests',
add: 'Add Endpoint',
},
contactInfo: 'Contact Info', contactInfo: 'Contact Info',
contactInfoPlaceholder: 'e.g., QQ: 123456789', contactInfoPlaceholder: 'e.g., QQ: 123456789',
contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.', contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.',

View File

@ -533,6 +533,14 @@ export default {
title: 'API 密钥', title: 'API 密钥',
description: '管理您的 API 密钥和访问令牌', description: '管理您的 API 密钥和访问令牌',
searchPlaceholder: '搜索名称或Key...', searchPlaceholder: '搜索名称或Key...',
endpoints: {
title: 'API 端点',
default: '默认',
copied: '已复制',
copiedHint: '已复制到剪贴板',
clickToCopy: '点击可复制此端点',
speedTest: '测速',
},
allGroups: '全部分组', allGroups: '全部分组',
allStatus: '全部状态', allStatus: '全部状态',
createKey: '创建密钥', createKey: '创建密钥',
@ -2009,6 +2017,8 @@ export default {
expiresAt: '过期时间', expiresAt: '过期时间',
actions: '操作' actions: '操作'
}, },
allPrivacyModes: '全部Privacy状态',
privacyUnset: '未设置',
privacyTrainingOff: '已关闭训练数据共享', privacyTrainingOff: '已关闭训练数据共享',
privacyCfBlocked: '被 Cloudflare 拦截,训练可能仍开启', privacyCfBlocked: '被 Cloudflare 拦截,训练可能仍开启',
privacyFailed: '关闭训练数据共享失败', privacyFailed: '关闭训练数据共享失败',
@ -3651,7 +3661,12 @@ export default {
typeRequest: '请求', typeRequest: '请求',
typeAuth: '认证', typeAuth: '认证',
typeRouting: '路由', typeRouting: '路由',
typeInternal: '内部' typeInternal: '内部',
endpoint: '端点',
requestType: '类型',
requestTypeSync: '同步',
requestTypeStream: '流式',
requestTypeWs: 'WS'
}, },
// Error Details Modal // Error Details Modal
errorDetails: { errorDetails: {
@ -3737,6 +3752,16 @@ export default {
latency: '请求时长', latency: '请求时长',
businessLimited: '业务限制', businessLimited: '业务限制',
requestPath: '请求路径', requestPath: '请求路径',
inboundEndpoint: '入站端点',
upstreamEndpoint: '上游端点',
requestedModel: '请求模型',
upstreamModel: '上游模型',
requestType: '请求类型',
requestTypeUnknown: '未知',
requestTypeSync: '同步',
requestTypeStream: '流式',
requestTypeWs: 'WebSocket',
modelMapping: '模型映射',
timings: '时序信息', timings: '时序信息',
auth: '认证', auth: '认证',
routing: '路由', routing: '路由',
@ -4324,6 +4349,18 @@ export default {
apiBaseUrl: 'API 端点地址', apiBaseUrl: 'API 端点地址',
apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址', apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址',
apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlPlaceholder: 'https://api.example.com',
customEndpoints: {
title: '自定义端点',
description: '添加额外的 API 端点地址用户可在「API Keys」页面快速复制',
itemLabel: '端点 #{n}',
name: '名称',
namePlaceholder: '如OpenAI Compatible',
endpointUrl: '端点地址',
endpointUrlPlaceholder: 'https://api2.example.com',
descriptionLabel: '介绍',
descriptionPlaceholder: '如:支持 OpenAI 格式请求',
add: '添加端点',
},
contactInfo: '客服联系方式', contactInfo: '客服联系方式',
contactInfoPlaceholder: '例如QQ: 123456789', contactInfoPlaceholder: '例如QQ: 123456789',
contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置', contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置',

View File

@ -330,6 +330,7 @@ export const useAppStore = defineStore('app', () => {
purchase_subscription_enabled: false, purchase_subscription_enabled: false,
purchase_subscription_url: '', purchase_subscription_url: '',
custom_menu_items: [], custom_menu_items: [],
custom_endpoints: [],
linuxdo_oauth_enabled: false, linuxdo_oauth_enabled: false,
sora_client_enabled: false, sora_client_enabled: false,
backend_mode_enabled: false, backend_mode_enabled: false,

View File

@ -84,6 +84,12 @@ export interface CustomMenuItem {
sort_order: number sort_order: number
} }
export interface CustomEndpoint {
name: string
endpoint: string
description: string
}
export interface PublicSettings { export interface PublicSettings {
registration_enabled: boolean registration_enabled: boolean
email_verify_enabled: boolean email_verify_enabled: boolean
@ -104,6 +110,7 @@ export interface PublicSettings {
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean linuxdo_oauth_enabled: boolean
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean backend_mode_enabled: boolean

View File

@ -581,7 +581,7 @@ const {
handlePageSizeChange: baseHandlePageSizeChange handlePageSizeChange: baseHandlePageSizeChange
} = useTableLoader<Account, any>({ } = useTableLoader<Account, any>({
fetchFn: adminAPI.accounts.list, fetchFn: adminAPI.accounts.list,
initialParams: { platform: '', type: '', status: '', group: '', search: '' } initialParams: { platform: '', type: '', status: '', privacy_mode: '', group: '', search: '' }
}) })
const { const {
@ -758,6 +758,7 @@ const refreshAccountsIncrementally = async () => {
platform?: string platform?: string
type?: string type?: string
status?: string status?: string
privacy_mode?: string
group?: string group?: string
search?: string search?: string

View File

@ -1248,6 +1248,81 @@
</p> </p>
</div> </div>
<!-- Custom Endpoints -->
<div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.customEndpoints.title') }}
</label>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.description') }}
</p>
<div class="space-y-3">
<div
v-for="(ep, index) in form.custom_endpoints"
:key="index"
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
>
<div class="mb-3 flex items-center justify-between">
<span class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.customEndpoints.itemLabel', { n: index + 1 }) }}
</span>
<button
type="button"
class="rounded p-1 text-red-400 hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
@click="removeEndpoint(index)"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" /></svg>
</button>
</div>
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.name') }}
</label>
<input
v-model="ep.name"
type="text"
class="input text-sm"
:placeholder="t('admin.settings.site.customEndpoints.namePlaceholder')"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.endpointUrl') }}
</label>
<input
v-model="ep.endpoint"
type="url"
class="input font-mono text-sm"
:placeholder="t('admin.settings.site.customEndpoints.endpointUrlPlaceholder')"
/>
</div>
<div class="sm:col-span-2">
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.descriptionLabel') }}
</label>
<input
v-model="ep.description"
type="text"
class="input text-sm"
:placeholder="t('admin.settings.site.customEndpoints.descriptionPlaceholder')"
/>
</div>
</div>
</div>
</div>
<button
type="button"
class="mt-3 flex w-full items-center justify-center gap-2 rounded-lg border-2 border-dashed border-gray-300 px-4 py-2.5 text-sm text-gray-500 transition-colors hover:border-primary-400 hover:text-primary-600 dark:border-dark-600 dark:text-gray-400 dark:hover:border-primary-500 dark:hover:text-primary-400"
@click="addEndpoint"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M12 4v16m8-8H4" /></svg>
{{ t('admin.settings.site.customEndpoints.add') }}
</button>
</div>
<!-- Contact Info --> <!-- Contact Info -->
<div> <div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"> <label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
@ -1580,7 +1655,7 @@
<button <button
type="button" type="button"
@click="testSmtpConnection" @click="testSmtpConnection"
:disabled="testingSmtp" :disabled="testingSmtp || loadFailed"
class="btn btn-secondary btn-sm" class="btn btn-secondary btn-sm"
> >
<svg v-if="testingSmtp" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24"> <svg v-if="testingSmtp" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
@ -1650,6 +1725,11 @@
v-model="form.smtp_password" v-model="form.smtp_password"
type="password" type="password"
class="input" class="input"
autocomplete="new-password"
autocapitalize="off"
spellcheck="false"
@keydown="smtpPasswordManuallyEdited = true"
@paste="smtpPasswordManuallyEdited = true"
:placeholder=" :placeholder="
form.smtp_password_configured form.smtp_password_configured
? t('admin.settings.smtp.passwordConfiguredPlaceholder') ? t('admin.settings.smtp.passwordConfiguredPlaceholder')
@ -1732,7 +1812,7 @@
<button <button
type="button" type="button"
@click="sendTestEmail" @click="sendTestEmail"
:disabled="sendingTestEmail || !testEmailAddress" :disabled="sendingTestEmail || !testEmailAddress || loadFailed"
class="btn btn-secondary" class="btn btn-secondary"
> >
<svg <svg
@ -1778,7 +1858,7 @@
<!-- Save Button --> <!-- Save Button -->
<div v-show="activeTab !== 'backup' && activeTab !== 'data'" class="flex justify-end"> <div v-show="activeTab !== 'backup' && activeTab !== 'data'" class="flex justify-end">
<button type="submit" :disabled="saving" class="btn btn-primary"> <button type="submit" :disabled="saving || loadFailed" class="btn btn-primary">
<svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24"> <svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle <circle
class="opacity-25" class="opacity-25"
@ -1849,9 +1929,11 @@ const settingsTabs = [
const { copyToClipboard } = useClipboard() const { copyToClipboard } = useClipboard()
const loading = ref(true) const loading = ref(true)
const loadFailed = ref(false)
const saving = ref(false) const saving = ref(false)
const testingSmtp = ref(false) const testingSmtp = ref(false)
const sendingTestEmail = ref(false) const sendingTestEmail = ref(false)
const smtpPasswordManuallyEdited = ref(false)
const testEmailAddress = ref('') const testEmailAddress = ref('')
const registrationEmailSuffixWhitelistTags = ref<string[]>([]) const registrationEmailSuffixWhitelistTags = ref<string[]>([])
const registrationEmailSuffixWhitelistDraft = ref('') const registrationEmailSuffixWhitelistDraft = ref('')
@ -1945,6 +2027,7 @@ const form = reactive<SettingsForm>({
purchase_subscription_url: '', purchase_subscription_url: '',
sora_client_enabled: false, sora_client_enabled: false,
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>, custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
custom_endpoints: [] as Array<{name: string; endpoint: string; description: string}>,
frontend_url: '', frontend_url: '',
smtp_host: '', smtp_host: '',
smtp_port: 587, smtp_port: 587,
@ -2114,8 +2197,18 @@ function moveMenuItem(index: number, direction: -1 | 1) {
}) })
} }
// Custom endpoint management
function addEndpoint() {
form.custom_endpoints.push({ name: '', endpoint: '', description: '' })
}
function removeEndpoint(index: number) {
form.custom_endpoints.splice(index, 1)
}
async function loadSettings() { async function loadSettings() {
loading.value = true loading.value = true
loadFailed.value = false
try { try {
const settings = await adminAPI.settings.getSettings() const settings = await adminAPI.settings.getSettings()
Object.assign(form, settings) Object.assign(form, settings)
@ -2133,9 +2226,11 @@ async function loadSettings() {
) )
registrationEmailSuffixWhitelistDraft.value = '' registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = '' form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = '' form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = '' form.linuxdo_connect_client_secret = ''
} catch (error: any) { } catch (error: any) {
loadFailed.value = true
appStore.showError( appStore.showError(
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError')) t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
) )
@ -2253,6 +2348,7 @@ async function saveSettings() {
purchase_subscription_url: form.purchase_subscription_url, purchase_subscription_url: form.purchase_subscription_url,
sora_client_enabled: form.sora_client_enabled, sora_client_enabled: form.sora_client_enabled,
custom_menu_items: form.custom_menu_items, custom_menu_items: form.custom_menu_items,
custom_endpoints: form.custom_endpoints,
frontend_url: form.frontend_url, frontend_url: form.frontend_url,
smtp_host: form.smtp_host, smtp_host: form.smtp_host,
smtp_port: form.smtp_port, smtp_port: form.smtp_port,
@ -2286,6 +2382,7 @@ async function saveSettings() {
) )
registrationEmailSuffixWhitelistDraft.value = '' registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = '' form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = '' form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = '' form.linuxdo_connect_client_secret = ''
// Refresh cached settings so sidebar/header update immediately // Refresh cached settings so sidebar/header update immediately
@ -2304,11 +2401,12 @@ async function saveSettings() {
async function testSmtpConnection() { async function testSmtpConnection() {
testingSmtp.value = true testingSmtp.value = true
try { try {
const smtpPasswordForTest = smtpPasswordManuallyEdited.value ? form.smtp_password : ''
const result = await adminAPI.settings.testSmtpConnection({ const result = await adminAPI.settings.testSmtpConnection({
smtp_host: form.smtp_host, smtp_host: form.smtp_host,
smtp_port: form.smtp_port, smtp_port: form.smtp_port,
smtp_username: form.smtp_username, smtp_username: form.smtp_username,
smtp_password: form.smtp_password, smtp_password: smtpPasswordForTest,
smtp_use_tls: form.smtp_use_tls smtp_use_tls: form.smtp_use_tls
}) })
// API returns { message: "..." } on success, errors are thrown as exceptions // API returns { message: "..." } on success, errors are thrown as exceptions
@ -2330,12 +2428,13 @@ async function sendTestEmail() {
sendingTestEmail.value = true sendingTestEmail.value = true
try { try {
const smtpPasswordForSend = smtpPasswordManuallyEdited.value ? form.smtp_password : ''
const result = await adminAPI.settings.sendTestEmail({ const result = await adminAPI.settings.sendTestEmail({
email: testEmailAddress.value, email: testEmailAddress.value,
smtp_host: form.smtp_host, smtp_host: form.smtp_host,
smtp_port: form.smtp_port, smtp_port: form.smtp_port,
smtp_username: form.smtp_username, smtp_username: form.smtp_username,
smtp_password: form.smtp_password, smtp_password: smtpPasswordForSend,
smtp_from_email: form.smtp_from_email, smtp_from_email: form.smtp_from_email,
smtp_from_name: form.smtp_from_name, smtp_from_name: form.smtp_from_name,
smtp_use_tls: form.smtp_use_tls smtp_use_tls: form.smtp_use_tls

View File

@ -59,7 +59,28 @@
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900"> <div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.model') }}</div> <div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.model') }}</div>
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white"> <div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">
{{ detail.model || '—' }} <template v-if="hasModelMapping(detail)">
<span class="font-mono">{{ detail.requested_model }}</span>
<span class="mx-1 text-gray-400"></span>
<span class="font-mono text-primary-600 dark:text-primary-400">{{ detail.upstream_model }}</span>
</template>
<template v-else>
{{ displayModel(detail) || '—' }}
</template>
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.inboundEndpoint') }}</div>
<div class="mt-1 break-all font-mono text-sm font-medium text-gray-900 dark:text-white">
{{ detail.inbound_endpoint || '—' }}
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.upstreamEndpoint') }}</div>
<div class="mt-1 break-all font-mono text-sm font-medium text-gray-900 dark:text-white">
{{ detail.upstream_endpoint || '—' }}
</div> </div>
</div> </div>
@ -72,6 +93,13 @@
</div> </div>
</div> </div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.requestType') }}</div>
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">
{{ formatRequestTypeLabel(detail.request_type) }}
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900"> <div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.message') }}</div> <div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.message') }}</div>
<div class="mt-1 truncate text-sm font-medium text-gray-900 dark:text-white" :title="detail.message"> <div class="mt-1 truncate text-sm font-medium text-gray-900 dark:text-white" :title="detail.message">
@ -213,6 +241,31 @@ function isUpstreamError(d: OpsErrorDetail | null): boolean {
return phase === 'upstream' && owner === 'provider' return phase === 'upstream' && owner === 'provider'
} }
function formatRequestTypeLabel(type: number | null | undefined): string {
switch (type) {
case 1: return t('admin.ops.errorDetail.requestTypeSync')
case 2: return t('admin.ops.errorDetail.requestTypeStream')
case 3: return t('admin.ops.errorDetail.requestTypeWs')
default: return t('admin.ops.errorDetail.requestTypeUnknown')
}
}
function hasModelMapping(d: OpsErrorDetail | null): boolean {
if (!d) return false
const requested = String(d.requested_model || '').trim()
const upstream = String(d.upstream_model || '').trim()
return !!requested && !!upstream && requested !== upstream
}
function displayModel(d: OpsErrorDetail | null): string {
if (!d) return ''
const upstream = String(d.upstream_model || '').trim()
if (upstream) return upstream
const requested = String(d.requested_model || '').trim()
if (requested) return requested
return String(d.model || '').trim()
}
const correlatedUpstream = ref<OpsErrorDetail[]>([]) const correlatedUpstream = ref<OpsErrorDetail[]>([])
const correlatedUpstreamLoading = ref(false) const correlatedUpstreamLoading = ref(false)

View File

@ -17,6 +17,9 @@
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400"> <th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('admin.ops.errorLog.type') }} {{ t('admin.ops.errorLog.type') }}
</th> </th>
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('admin.ops.errorLog.endpoint') }}
</th>
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400"> <th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('admin.ops.errorLog.platform') }} {{ t('admin.ops.errorLog.platform') }}
</th> </th>
@ -42,7 +45,7 @@
</thead> </thead>
<tbody class="divide-y divide-gray-100 dark:divide-dark-700"> <tbody class="divide-y divide-gray-100 dark:divide-dark-700">
<tr v-if="rows.length === 0"> <tr v-if="rows.length === 0">
<td colspan="9" class="py-12 text-center text-sm text-gray-400 dark:text-dark-500"> <td colspan="10" class="py-12 text-center text-sm text-gray-400 dark:text-dark-500">
{{ t('admin.ops.errorLog.noErrors') }} {{ t('admin.ops.errorLog.noErrors') }}
</td> </td>
</tr> </tr>
@ -74,6 +77,18 @@
</span> </span>
</td> </td>
<!-- Endpoint -->
<td class="px-4 py-2">
<div class="max-w-[160px]">
<el-tooltip v-if="log.inbound_endpoint" :content="formatEndpointTooltip(log)" placement="top" :show-after="500">
<span class="truncate font-mono text-[11px] text-gray-700 dark:text-gray-300">
{{ log.inbound_endpoint }}
</span>
</el-tooltip>
<span v-else class="text-xs text-gray-400">-</span>
</div>
</td>
<!-- Platform --> <!-- Platform -->
<td class="whitespace-nowrap px-4 py-2"> <td class="whitespace-nowrap px-4 py-2">
<span class="inline-flex items-center rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold uppercase text-gray-600 dark:bg-dark-700 dark:text-gray-300"> <span class="inline-flex items-center rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold uppercase text-gray-600 dark:bg-dark-700 dark:text-gray-300">
@ -83,11 +98,22 @@
<!-- Model --> <!-- Model -->
<td class="px-4 py-2"> <td class="px-4 py-2">
<div class="max-w-[120px] truncate" :title="log.model"> <div class="max-w-[160px]">
<span v-if="log.model" class="font-mono text-[11px] text-gray-700 dark:text-gray-300"> <template v-if="hasModelMapping(log)">
{{ log.model }} <el-tooltip :content="modelMappingTooltip(log)" placement="top" :show-after="500">
</span> <span class="flex items-center gap-1 truncate font-mono text-[11px] text-gray-700 dark:text-gray-300">
<span v-else class="text-xs text-gray-400">-</span> <span class="truncate">{{ log.requested_model }}</span>
<span class="flex-shrink-0 text-gray-400"></span>
<span class="truncate text-primary-600 dark:text-primary-400">{{ log.upstream_model }}</span>
</span>
</el-tooltip>
</template>
<template v-else>
<span v-if="displayModel(log)" class="truncate font-mono text-[11px] text-gray-700 dark:text-gray-300" :title="displayModel(log)">
{{ displayModel(log) }}
</span>
<span v-else class="text-xs text-gray-400">-</span>
</template>
</div> </div>
</td> </td>
@ -138,6 +164,12 @@
> >
{{ log.severity }} {{ log.severity }}
</span> </span>
<span
v-if="log.request_type != null && log.request_type > 0"
class="rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold text-gray-600 dark:bg-dark-700 dark:text-gray-300"
>
{{ formatRequestType(log.request_type) }}
</span>
</div> </div>
</td> </td>
@ -193,6 +225,44 @@ function isUpstreamRow(log: OpsErrorLog): boolean {
return phase === 'upstream' && owner === 'provider' return phase === 'upstream' && owner === 'provider'
} }
function formatEndpointTooltip(log: OpsErrorLog): string {
const parts: string[] = []
if (log.inbound_endpoint) parts.push(`Inbound: ${log.inbound_endpoint}`)
if (log.upstream_endpoint) parts.push(`Upstream: ${log.upstream_endpoint}`)
return parts.join('\n') || ''
}
function hasModelMapping(log: OpsErrorLog): boolean {
const requested = String(log.requested_model || '').trim()
const upstream = String(log.upstream_model || '').trim()
return !!requested && !!upstream && requested !== upstream
}
function modelMappingTooltip(log: OpsErrorLog): string {
const requested = String(log.requested_model || '').trim()
const upstream = String(log.upstream_model || '').trim()
if (!requested && !upstream) return ''
if (requested && upstream) return `${requested}${upstream}`
return upstream || requested
}
function displayModel(log: OpsErrorLog): string {
const upstream = String(log.upstream_model || '').trim()
if (upstream) return upstream
const requested = String(log.requested_model || '').trim()
if (requested) return requested
return String(log.model || '').trim()
}
function formatRequestType(type: number | null | undefined): string {
switch (type) {
case 1: return t('admin.ops.errorLog.requestTypeSync')
case 2: return t('admin.ops.errorLog.requestTypeStream')
case 3: return t('admin.ops.errorLog.requestTypeWs')
default: return ''
}
}
function getTypeBadge(log: OpsErrorLog): { label: string; className: string } { function getTypeBadge(log: OpsErrorLog): { label: string; className: string } {
const phase = String(log.phase || '').toLowerCase() const phase = String(log.phase || '').toLowerCase()
const owner = String(log.error_owner || '').toLowerCase() const owner = String(log.error_owner || '').toLowerCase()

View File

@ -344,7 +344,7 @@ onMounted(async () => {
<div class="text-xs font-semibold text-gray-700 dark:text-gray-200">运行时日志配置实时生效</div> <div class="text-xs font-semibold text-gray-700 dark:text-gray-200">运行时日志配置实时生效</div>
<span v-if="runtimeLoading" class="text-xs text-gray-500">加载中...</span> <span v-if="runtimeLoading" class="text-xs text-gray-500">加载中...</span>
</div> </div>
<div class="grid grid-cols-1 gap-3 md:grid-cols-6"> <div class="grid grid-cols-1 gap-3 md:grid-cols-2 xl:grid-cols-6">
<label class="text-xs text-gray-600 dark:text-gray-300"> <label class="text-xs text-gray-600 dark:text-gray-300">
级别 级别
<select v-model="runtimeConfig.level" class="input mt-1"> <select v-model="runtimeConfig.level" class="input mt-1">
@ -374,21 +374,27 @@ onMounted(async () => {
保留天数 保留天数
<input v-model.number="runtimeConfig.retention_days" type="number" min="1" max="3650" class="input mt-1" /> <input v-model.number="runtimeConfig.retention_days" type="number" min="1" max="3650" class="input mt-1" />
</label> </label>
<div class="flex items-end gap-2"> <div class="md:col-span-2 xl:col-span-6">
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300"> <div class="grid gap-3 lg:grid-cols-[minmax(0,1fr)_auto] lg:items-end">
<input v-model="runtimeConfig.caller" type="checkbox" /> <div class="flex flex-wrap items-center gap-x-4 gap-y-2">
caller <label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
</label> <input v-model="runtimeConfig.caller" type="checkbox" />
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300"> caller
<input v-model="runtimeConfig.enable_sampling" type="checkbox" /> </label>
sampling <label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
</label> <input v-model="runtimeConfig.enable_sampling" type="checkbox" />
<button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig"> sampling
{{ runtimeSaving ? '保存中...' : '保存并生效' }} </label>
</button> </div>
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig"> <div class="flex flex-wrap items-center gap-2 lg:justify-end">
回滚默认值 <button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig">
</button> {{ runtimeSaving ? '保存中...' : '保存并生效' }}
</button>
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig">
回滚默认值
</button>
</div>
</div>
</div> </div>
</div> </div>
<p v-if="health.last_error" class="mt-2 text-xs text-red-600 dark:text-red-400">最近写入错误{{ health.last_error }}</p> <p v-if="health.last_error" class="mt-2 text-xs text-red-600 dark:text-red-400">最近写入错误{{ health.last_error }}</p>

View File

@ -2,24 +2,31 @@
<AppLayout> <AppLayout>
<TablePageLayout> <TablePageLayout>
<template #filters> <template #filters>
<div class="flex flex-wrap items-center gap-3"> <div class="flex flex-col gap-3">
<SearchInput <div class="flex flex-wrap items-center gap-3">
v-model="filterSearch" <SearchInput
:placeholder="t('keys.searchPlaceholder')" v-model="filterSearch"
class="w-full sm:w-64" :placeholder="t('keys.searchPlaceholder')"
@search="onFilterChange" class="w-full sm:w-64"
/> @search="onFilterChange"
<Select />
:model-value="filterGroupId" <Select
class="w-40" :model-value="filterGroupId"
:options="groupFilterOptions" class="w-40"
@update:model-value="onGroupFilterChange" :options="groupFilterOptions"
/> @update:model-value="onGroupFilterChange"
<Select />
:model-value="filterStatus" <Select
class="w-40" :model-value="filterStatus"
:options="statusFilterOptions" class="w-40"
@update:model-value="onStatusFilterChange" :options="statusFilterOptions"
@update:model-value="onStatusFilterChange"
/>
</div>
<EndpointPopover
v-if="publicSettings?.api_base_url || (publicSettings?.custom_endpoints?.length ?? 0) > 0"
:api-base-url="publicSettings?.api_base_url || ''"
:custom-endpoints="publicSettings?.custom_endpoints || []"
/> />
</div> </div>
</template> </template>
@ -1050,6 +1057,7 @@ import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import SearchInput from '@/components/common/SearchInput.vue' import SearchInput from '@/components/common/SearchInput.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import UseKeyModal from '@/components/keys/UseKeyModal.vue' import UseKeyModal from '@/components/keys/UseKeyModal.vue'
import EndpointPopover from '@/components/keys/EndpointPopover.vue'
import GroupBadge from '@/components/common/GroupBadge.vue' import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue' import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types' import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types'