merge: resolve upstream main conflicts for bulk OpenAI passthrough

This commit is contained in:
Wang Lvyuan 2026-03-24 19:27:51 +08:00
commit bb399e56b0
98 changed files with 5168 additions and 213 deletions

View File

@ -61,6 +61,9 @@ temp/
deploy/install.sh
deploy/sub2api.service
deploy/sub2api-sudoers
deploy/data/
deploy/postgres_data/
deploy/redis_data/
# GoReleaser
.goreleaser.yaml

View File

@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()

View File

@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
if err != nil {
return nil, err
}

View File

@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
if err != nil {
response.ErrorFrom(c, err)
return
@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}

View File

@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
@ -176,6 +177,7 @@ type UpdateSettingsRequest struct {
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@ -231,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
if req.SMTPHost == "" && previousSettings.SMTPHost != "" {
req.SMTPHost = previousSettings.SMTPHost
req.SMTPPort = previousSettings.SMTPPort
req.SMTPUsername = previousSettings.SMTPUsername
req.SMTPFrom = previousSettings.SMTPFrom
req.SMTPFromName = previousSettings.SMTPFromName
req.SMTPUseTLS = previousSettings.SMTPUseTLS
}
// Turnstile 参数验证
if req.TurnstileEnabled {
// 检查必填字段
@ -417,6 +435,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
customMenuJSON = string(menuBytes)
}
// 自定义端点验证
const (
maxCustomEndpoints = 10
maxEndpointNameLen = 50
maxEndpointURLLen = 2048
maxEndpointDescriptionLen = 200
)
customEndpointsJSON := previousSettings.CustomEndpoints
if req.CustomEndpoints != nil {
endpoints := *req.CustomEndpoints
if len(endpoints) > maxCustomEndpoints {
response.BadRequest(c, "Too many custom endpoints (max 10)")
return
}
for _, ep := range endpoints {
if strings.TrimSpace(ep.Name) == "" {
response.BadRequest(c, "Custom endpoint name is required")
return
}
if len(ep.Name) > maxEndpointNameLen {
response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)")
return
}
if strings.TrimSpace(ep.Endpoint) == "" {
response.BadRequest(c, "Custom endpoint URL is required")
return
}
if len(ep.Endpoint) > maxEndpointURLLen {
response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil {
response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL")
return
}
if len(ep.Description) > maxEndpointDescriptionLen {
response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)")
return
}
}
endpointBytes, err := json.Marshal(endpoints)
if err != nil {
response.BadRequest(c, "Failed to serialize custom endpoints")
return
}
customEndpointsJSON = string(endpointBytes)
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@ -495,6 +562,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
@ -592,6 +660,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions,
@ -828,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
@ -844,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
return
}
if req.SMTPPort <= 0 {
req.SMTPPort = 587
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SMTPPassword
if password == "" {
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
if req.SMTPHost == "" && savedConfig != nil {
req.SMTPHost = savedConfig.Host
}
if req.SMTPPort <= 0 {
if savedConfig != nil && savedConfig.Port > 0 {
req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
}
}
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{
Host: req.SMTPHost,
@ -877,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"`
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
@ -895,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return
}
if req.SMTPPort <= 0 {
req.SMTPPort = 587
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SMTPPassword
if password == "" {
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
if req.SMTPHost == "" && savedConfig != nil {
req.SMTPHost = savedConfig.Host
}
if req.SMTPPort <= 0 {
if savedConfig != nil && savedConfig.Port > 0 {
req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
}
}
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPFrom == "" && savedConfig != nil {
req.SMTPFrom = savedConfig.From
}
if req.SMTPFromName == "" && savedConfig != nil {
req.SMTPFromName = savedConfig.FromName
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{
Host: req.SMTPHost,

View File

@ -15,6 +15,13 @@ type CustomMenuItem struct {
SortOrder int `json:"sort_order"`
}
// CustomEndpoint represents an admin-configured API endpoint for quick copy.
type CustomEndpoint struct {
Name string `json:"name"`
Endpoint string `json:"endpoint"`
Description string `json:"description"`
}
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
@ -56,6 +63,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@ -114,6 +122,7 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
}
return filtered
}
// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint.
// Returns empty slice on empty/invalid input.
func ParseCustomEndpoints(raw string) []CustomEndpoint {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return []CustomEndpoint{}
}
var items []CustomEndpoint
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return []CustomEndpoint{}
}
return items
}

View File

@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 验证 model 必填
if reqModel == "" {
@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
// 获取订阅信息可能为nil
subscription, _ := middleware2.GetSubscriptionFromContext(c)

View File

@ -0,0 +1,289 @@
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
// POST /v1/chat/completions
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleCCFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.cc.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -0,0 +1,295 @@
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
// POST /v1/responses
// This converts Responses API requests to Anthropic format, forwards to Anthropic
// upstream, and converts responses back to Responses format.
func (h *GatewayHandler) Responses(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream using gjson (like OpenAI handler)
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
// The existing service-layer checkClaudeCodeRestriction handles degradation
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.responsesErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Can't failover if streaming content already sent
if c.Writer.Size() != writerSizeBeforeForward {
h.handleResponsesFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.responses.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// responsesErrorResponse writes an error in OpenAI Responses API format.
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": code,
"message": message,
},
})
}
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return // Can't write error after stream started
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)

View File

@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)

View File

@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.String("previous_response_id_kind", previousResponseIDKind),
)
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
var currentUserRelease func()
var currentAccountRelease func()

View File

@ -27,6 +27,9 @@ const (
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts"
@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
}
}
// setOpsEndpointContext stores upstream model and request type for ops error logging.
// Called by handlers after model mapping and request type determination.
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
if c == nil {
return
}
if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" {
c.Set(opsUpstreamModelKey, upstreamModel)
}
c.Set(opsRequestTypeKey, requestType)
}
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
}
return ""
}(),
Stream: stream,
Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: "upstream",
@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
}
return ""
}(),
Stream: stream,
Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase,

View File

@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) {
})
}
}
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream
v, ok := c.Get(opsUpstreamModelKey)
require.True(t, ok)
vStr, ok := v.(string)
require.True(t, ok)
require.Equal(t, "claude-3-5-sonnet-20241022", vStr)
rt, ok := c.Get(opsRequestTypeKey)
require.True(t, ok)
rtVal, ok := rt.(int16)
require.True(t, ok)
require.Equal(t, int16(2), rtVal)
}
func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
setOpsEndpointContext(c, "", int16(1))
_, ok := c.Get(opsUpstreamModelKey)
require.False(t, ok, "empty upstream model should not be stored")
rt, ok := c.Get(opsRequestTypeKey)
require.True(t, ok)
rtVal, ok := rt.(int16)
require.True(t, ok)
require.Equal(t, int16(1), rtVal)
}
func TestSetOpsEndpointContext_NilContext(t *testing.T) {
require.NotPanics(t, func() {
setOpsEndpointContext(nil, "model", int16(1))
})
}

View File

@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled,

View File

@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {

View File

@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}
setOpsRequestContext(c, reqModel, clientStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {

View File

@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {

View File

@ -0,0 +1,521 @@
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: AnthropicResponse → ResponsesResponse
// ---------------------------------------------------------------------------
// AnthropicToResponsesResponse converts an Anthropic Messages response into a
// Responses API response. This is the reverse of ResponsesToAnthropic and
// enables Anthropic upstream responses to be returned in OpenAI Responses format.
func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
id := resp.ID
if id == "" {
id = generateResponsesID()
}
out := &ResponsesResponse{
ID: id,
Object: "response",
Model: resp.Model,
}
var outputs []ResponsesOutput
var msgParts []ResponsesContentPart
for _, block := range resp.Content {
switch block.Type {
case "thinking":
if block.Thinking != "" {
outputs = append(outputs, ResponsesOutput{
Type: "reasoning",
ID: generateItemID(),
Summary: []ResponsesSummary{{
Type: "summary_text",
Text: block.Thinking,
}},
})
}
case "text":
if block.Text != "" {
msgParts = append(msgParts, ResponsesContentPart{
Type: "output_text",
Text: block.Text,
})
}
case "tool_use":
args := "{}"
if len(block.Input) > 0 {
args = string(block.Input)
}
outputs = append(outputs, ResponsesOutput{
Type: "function_call",
ID: generateItemID(),
CallID: toResponsesCallID(block.ID),
Name: block.Name,
Arguments: args,
Status: "completed",
})
}
}
// Assemble message output item from text parts
if len(msgParts) > 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: msgParts,
Status: "completed",
})
}
if len(outputs) == 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
Status: "completed",
})
}
out.Output = outputs
// Map stop_reason → status
out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content)
if out.Status == "incomplete" {
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
}
// Usage
out.Usage = &ResponsesUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.CacheReadInputTokens > 0 {
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: resp.Usage.CacheReadInputTokens,
}
}
return out
}
// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status.
func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string {
switch stopReason {
case "max_tokens":
return "incomplete"
case "end_turn", "tool_use", "stop_sequence":
return "completed"
default:
return "completed"
}
}
// ---------------------------------------------------------------------------
// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
// AnthropicEventToResponsesState tracks state for converting a sequence of
// Anthropic SSE events into Responses SSE events.
type AnthropicEventToResponsesState struct {
ResponseID string
Model string
Created int64
SequenceNumber int
// CreatedSent tracks whether response.created has been emitted.
CreatedSent bool
// CompletedSent tracks whether the terminal event has been emitted.
CompletedSent bool
// Current output tracking
OutputIndex int
CurrentItemID string
CurrentItemType string // "message" | "function_call" | "reasoning"
// For message output: accumulate text parts
ContentIndex int
// For function_call: track per-output info
CurrentCallID string
CurrentName string
// Usage from message_delta
InputTokens int
OutputTokens int
CacheReadInputTokens int
}
// NewAnthropicEventToResponsesState returns an initialised stream state.
func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState {
return &AnthropicEventToResponsesState{
Created: time.Now().Unix(),
}
}
// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into
// zero or more Responses SSE events, updating state as it goes.
func AnthropicEventToResponsesEvents(
evt *AnthropicStreamEvent,
state *AnthropicEventToResponsesState,
) []ResponsesStreamEvent {
switch evt.Type {
case "message_start":
return anthToResHandleMessageStart(evt, state)
case "content_block_start":
return anthToResHandleContentBlockStart(evt, state)
case "content_block_delta":
return anthToResHandleContentBlockDelta(evt, state)
case "content_block_stop":
return anthToResHandleContentBlockStop(evt, state)
case "message_delta":
return anthToResHandleMessageDelta(evt, state)
case "message_stop":
return anthToResHandleMessageStop(state)
default:
return nil
}
}
// FinalizeAnthropicResponsesStream emits synthetic termination events if the
// stream ended without a proper message_stop.
func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if !state.CreatedSent || state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
// Close any open item
events = append(events, closeCurrentResponsesItem(state)...)
// Emit response.completed
events = append(events, makeResponsesCompletedEvent(state, "completed", nil))
state.CompletedSent = true
return events
}
// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line.
func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) {
data, err := json.Marshal(evt)
if err != nil {
return "", err
}
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
}
// --- internal handlers ---
func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.Message != nil {
state.ResponseID = evt.Message.ID
if state.Model == "" {
state.Model = evt.Message.Model
}
if evt.Message.Usage.InputTokens > 0 {
state.InputTokens = evt.Message.Usage.InputTokens
}
}
if state.CreatedSent {
return nil
}
state.CreatedSent = true
// Emit response.created
return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)}
}
func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.ContentBlock == nil {
return nil
}
var events []ResponsesStreamEvent
switch evt.ContentBlock.Type {
case "thinking":
state.CurrentItemID = generateItemID()
state.CurrentItemType = "reasoning"
state.ContentIndex = 0
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "reasoning",
ID: state.CurrentItemID,
},
}))
case "text":
// If we don't have an open message item, open one
if state.CurrentItemType != "message" {
state.CurrentItemID = generateItemID()
state.CurrentItemType = "message"
state.ContentIndex = 0
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "message",
ID: state.CurrentItemID,
Role: "assistant",
Status: "in_progress",
},
}))
}
case "tool_use":
// Close previous item if any
events = append(events, closeCurrentResponsesItem(state)...)
state.CurrentItemID = generateItemID()
state.CurrentItemType = "function_call"
state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID)
state.CurrentName = evt.ContentBlock.Name
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "function_call",
ID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
Status: "in_progress",
},
}))
}
return events
}
func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.Delta == nil {
return nil
}
switch evt.Delta.Type {
case "text_delta":
if evt.Delta.Text == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ContentIndex: state.ContentIndex,
Delta: evt.Delta.Text,
ItemID: state.CurrentItemID,
})}
case "thinking_delta":
if evt.Delta.Thinking == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
SummaryIndex: 0,
Delta: evt.Delta.Thinking,
ItemID: state.CurrentItemID,
})}
case "input_json_delta":
if evt.Delta.PartialJSON == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Delta: evt.Delta.PartialJSON,
ItemID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
})}
case "signature_delta":
// Anthropic signature deltas have no Responses equivalent; skip
return nil
}
return nil
}
func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
switch state.CurrentItemType {
case "reasoning":
// Emit reasoning summary done + output item done
events := []ResponsesStreamEvent{
makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
SummaryIndex: 0,
ItemID: state.CurrentItemID,
}),
}
events = append(events, closeCurrentResponsesItem(state)...)
return events
case "function_call":
// Emit function_call_arguments.done + output item done
events := []ResponsesStreamEvent{
makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ItemID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
}),
}
events = append(events, closeCurrentResponsesItem(state)...)
return events
case "message":
// Emit output_text.done (text block is done, but message item stays open for potential more blocks)
return []ResponsesStreamEvent{
makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ContentIndex: state.ContentIndex,
ItemID: state.CurrentItemID,
}),
}
}
return nil
}
func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
// Update usage
if evt.Usage != nil {
state.OutputTokens = evt.Usage.OutputTokens
if evt.Usage.CacheReadInputTokens > 0 {
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
}
}
return nil
}
func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
// Close any open item
events = append(events, closeCurrentResponsesItem(state)...)
// Determine status
status := "completed"
var incompleteDetails *ResponsesIncompleteDetails
// Emit response.completed
events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails))
state.CompletedSent = true
return events
}
// --- helper functions ---
func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if state.CurrentItemType == "" {
return nil
}
itemType := state.CurrentItemType
itemID := state.CurrentItemID
// Reset
state.CurrentItemType = ""
state.CurrentItemID = ""
state.CurrentCallID = ""
state.CurrentName = ""
state.OutputIndex++
state.ContentIndex = 0
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex - 1, // Use the index before increment
Item: &ResponsesOutput{
Type: itemType,
ID: itemID,
Status: "completed",
},
})}
}
func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
return ResponsesStreamEvent{
Type: "response.created",
SequenceNumber: seq,
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: "in_progress",
Output: []ResponsesOutput{},
},
}
}
func makeResponsesCompletedEvent(
state *AnthropicEventToResponsesState,
status string,
incompleteDetails *ResponsesIncompleteDetails,
) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
usage := &ResponsesUsage{
InputTokens: state.InputTokens,
OutputTokens: state.OutputTokens,
TotalTokens: state.InputTokens + state.OutputTokens,
}
if state.CacheReadInputTokens > 0 {
usage.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: state.CacheReadInputTokens,
}
}
return ResponsesStreamEvent{
Type: "response.completed",
SequenceNumber: seq,
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: status,
Output: []ResponsesOutput{}, // Simplified; full output tracking would add complexity
Usage: usage,
IncompleteDetails: incompleteDetails,
},
}
}
func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
evt := *template
evt.Type = eventType
evt.SequenceNumber = seq
return evt
}
func generateResponsesID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "resp_" + hex.EncodeToString(b)
}
func generateItemID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "item_" + hex.EncodeToString(b)
}

View File

@ -0,0 +1,464 @@
package apicompat
import (
"encoding/json"
"fmt"
"strings"
)
// ResponsesToAnthropicRequest converts a Responses API request into an
// Anthropic Messages request. This is the reverse of AnthropicToResponses and
// enables Anthropic platform groups to accept OpenAI Responses API requests
// by converting them to the native /v1/messages format before forwarding upstream.
func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) {
system, messages, err := convertResponsesInputToAnthropic(req.Input)
if err != nil {
return nil, err
}
out := &AnthropicRequest{
Model: req.Model,
Messages: messages,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: req.Stream,
}
if len(system) > 0 {
out.System = system
}
// max_output_tokens → max_tokens
if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 {
out.MaxTokens = *req.MaxOutputTokens
}
if out.MaxTokens == 0 {
// Anthropic requires max_tokens; default to a sensible value.
out.MaxTokens = 8192
}
// Convert tools
if len(req.Tools) > 0 {
out.Tools = convertResponsesToAnthropicTools(req.Tools)
}
// Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses)
if len(req.ToolChoice) > 0 {
tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice)
if err != nil {
return nil, fmt.Errorf("convert tool_choice: %w", err)
}
out.ToolChoice = tc
}
// reasoning.effort → output_config.effort + thinking
if req.Reasoning != nil && req.Reasoning.Effort != "" {
effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort)
out.OutputConfig = &AnthropicOutputConfig{Effort: effort}
// Enable thinking for non-low efforts
if effort != "low" {
out.Thinking = &AnthropicThinking{
Type: "enabled",
BudgetTokens: defaultThinkingBudget(effort),
}
}
}
return out, nil
}
// defaultThinkingBudget returns a sensible thinking budget based on effort level.
func defaultThinkingBudget(effort string) int {
switch effort {
case "low":
return 1024
case "medium":
return 4096
case "high":
return 10240
case "max":
return 32768
default:
return 10240
}
}
// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to
// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses.
//
// low → low
// medium → medium
// high → high
// xhigh → max
func mapResponsesEffortToAnthropic(effort string) string {
if effort == "xhigh" {
return "max"
}
return effort // low→low, medium→medium, high→high, unknown→passthrough
}
// convertResponsesInputToAnthropic extracts system prompt and messages from
// a Responses API input array. Returns the system as raw JSON (for Anthropic's
// polymorphic system field) and a list of Anthropic messages.
func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) {
// Try as plain string input.
var inputStr string
if err := json.Unmarshal(inputRaw, &inputStr); err == nil {
content, _ := json.Marshal(inputStr)
return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil
}
var items []ResponsesInputItem
if err := json.Unmarshal(inputRaw, &items); err != nil {
return nil, nil, fmt.Errorf("parse responses input: %w", err)
}
var system json.RawMessage
var messages []AnthropicMessage
for _, item := range items {
switch {
case item.Role == "system":
// System prompt → Anthropic system field
text := extractTextFromContent(item.Content)
if text != "" {
system, _ = json.Marshal(text)
}
case item.Type == "function_call":
// function_call → assistant message with tool_use block
input := json.RawMessage("{}")
if item.Arguments != "" {
input = json.RawMessage(item.Arguments)
}
block := AnthropicContentBlock{
Type: "tool_use",
ID: fromResponsesCallIDToAnthropic(item.CallID),
Name: item.Name,
Input: input,
}
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
messages = append(messages, AnthropicMessage{
Role: "assistant",
Content: blockJSON,
})
case item.Type == "function_call_output":
// function_call_output → user message with tool_result block
outputContent := item.Output
if outputContent == "" {
outputContent = "(empty)"
}
contentJSON, _ := json.Marshal(outputContent)
block := AnthropicContentBlock{
Type: "tool_result",
ToolUseID: fromResponsesCallIDToAnthropic(item.CallID),
Content: contentJSON,
}
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
messages = append(messages, AnthropicMessage{
Role: "user",
Content: blockJSON,
})
case item.Role == "user":
content, err := convertResponsesUserToAnthropicContent(item.Content)
if err != nil {
return nil, nil, err
}
messages = append(messages, AnthropicMessage{
Role: "user",
Content: content,
})
case item.Role == "assistant":
content, err := convertResponsesAssistantToAnthropicContent(item.Content)
if err != nil {
return nil, nil, err
}
messages = append(messages, AnthropicMessage{
Role: "assistant",
Content: content,
})
default:
// Unknown role/type — attempt as user message
if item.Content != nil {
messages = append(messages, AnthropicMessage{
Role: "user",
Content: item.Content,
})
}
}
}
// Merge consecutive same-role messages (Anthropic requires alternating roles)
messages = mergeConsecutiveMessages(messages)
return system, messages, nil
}
// extractTextFromContent extracts text from a content field that may be a
// plain string or an array of content parts.
func extractTextFromContent(raw json.RawMessage) string {
if len(raw) == 0 {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s
}
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
var texts []string
for _, p := range parts {
if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" {
texts = append(texts, p.Text)
}
}
return strings.Join(texts, "\n\n")
}
return ""
}
// convertResponsesUserToAnthropicContent converts a Responses user message
// content field into Anthropic content blocks JSON.
func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.Marshal("") // empty string content
}
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Array of content parts → Anthropic content blocks.
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err != nil {
// Pass through as-is if we can't parse
return raw, nil
}
var blocks []AnthropicContentBlock
for _, p := range parts {
switch p.Type {
case "input_text", "text":
if p.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: p.Text,
})
}
case "input_image":
src := dataURIToAnthropicImageSource(p.ImageURL)
if src != nil {
blocks = append(blocks, AnthropicContentBlock{
Type: "image",
Source: src,
})
}
}
}
if len(blocks) == 0 {
return json.Marshal("")
}
return json.Marshal(blocks)
}
// convertResponsesAssistantToAnthropicContent converts a Responses assistant
// message content field into Anthropic content blocks JSON.
func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}})
}
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}})
}
// Array of content parts → Anthropic content blocks.
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err != nil {
return raw, nil
}
var blocks []AnthropicContentBlock
for _, p := range parts {
switch p.Type {
case "output_text", "text":
if p.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: p.Text,
})
}
}
}
if len(blocks) == 0 {
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
}
return json.Marshal(blocks)
}
// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to
// Anthropic format. Reverses toResponsesCallID.
func fromResponsesCallIDToAnthropic(id string) string {
// If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it
if after, ok := strings.CutPrefix(id, "fc_"); ok {
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
return after
}
}
// Generate a synthetic Anthropic tool ID
if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") {
return "toolu_" + id
}
return id
}
// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource.
func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource {
if !strings.HasPrefix(dataURI, "data:") {
return nil
}
// Format: data:<media_type>;base64,<data>
rest := strings.TrimPrefix(dataURI, "data:")
semicolonIdx := strings.Index(rest, ";")
if semicolonIdx < 0 {
return nil
}
mediaType := rest[:semicolonIdx]
rest = rest[semicolonIdx+1:]
if !strings.HasPrefix(rest, "base64,") {
return nil
}
data := strings.TrimPrefix(rest, "base64,")
return &AnthropicImageSource{
Type: "base64",
MediaType: mediaType,
Data: data,
}
}
// mergeConsecutiveMessages merges consecutive messages with the same role
// because Anthropic requires alternating user/assistant turns.
func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage {
if len(messages) <= 1 {
return messages
}
var merged []AnthropicMessage
for _, msg := range messages {
if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role {
merged = append(merged, msg)
continue
}
// Same role — merge content arrays
last := &merged[len(merged)-1]
lastBlocks := parseContentBlocks(last.Content)
newBlocks := parseContentBlocks(msg.Content)
combined := append(lastBlocks, newBlocks...)
last.Content, _ = json.Marshal(combined)
}
return merged
}
// parseContentBlocks attempts to parse content as []AnthropicContentBlock.
// If it's a string, wraps it in a text block.
func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock {
var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err == nil {
return blocks
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return []AnthropicContentBlock{{Type: "text", Text: s}}
}
return nil
}
// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format.
// Reverse of convertAnthropicToolsToResponses.
func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
case "web_search":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",
})
case "function":
out = append(out, AnthropicTool{
Name: t.Name,
Description: t.Description,
InputSchema: normalizeAnthropicInputSchema(t.Parameters),
})
default:
// Pass through unknown tool types
out = append(out, AnthropicTool{
Type: t.Type,
Name: t.Name,
Description: t.Description,
InputSchema: t.Parameters,
})
}
}
return out
}
// normalizeAnthropicInputSchema ensures the input_schema has a "type" field.
func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
if len(schema) == 0 || string(schema) == "null" {
return json.RawMessage(`{"type":"object","properties":{}}`)
}
return schema
}
// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format.
// Reverse of convertAnthropicToolChoiceToResponses.
//
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first
var s string
if err := json.Unmarshal(raw, &s); err == nil {
switch s {
case "auto":
return json.Marshal(map[string]string{"type": "auto"})
case "required":
return json.Marshal(map[string]string{"type": "any"})
case "none":
return json.Marshal(map[string]string{"type": "none"})
default:
return raw, nil
}
}
// Try as object with type=function
var tc struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
return json.Marshal(map[string]string{
"type": "tool",
"name": tc.Function.Name,
})
}
// Pass through unknown
return raw, nil
}

View File

@ -270,6 +270,7 @@ type OpenAIAuthClaims struct {
ChatGPTUserID string `json:"chatgpt_user_id"`
ChatGPTPlanType string `json:"chatgpt_plan_type"`
UserID string `json:"user_id"`
POID string `json:"poid"` // organization ID in access_token JWT
Organizations []OrganizationClaim `json:"organizations"`
}

View File

@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return nil
}
func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
_, err := r.client.Account.UpdateOneID(id).
SetCredentials(normalizeJSONMap(credentials)).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil {
@ -443,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
return r.ListWithFilters(ctx, params, "", "", "", "", 0, "")
}
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query()
if platform != "" {
@ -479,6 +490,20 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
} else if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
}
if privacyMode != "" {
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
path := sqljson.Path("privacy_mode")
switch privacyMode {
case service.AccountPrivacyModeUnsetFilter:
s.Where(entsql.Or(
entsql.Not(sqljson.HasKey(dbaccount.FieldExtra, path)),
sqljson.ValueEQ(dbaccount.FieldExtra, "", path),
))
default:
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, privacyMode, path))
}
}))
}
total, err := q.Count(ctx)
if err != nil {

View File

@ -208,15 +208,16 @@ func (s *AccountRepoSuite) TestList() {
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(client *dbent.Client)
platform string
accType string
status string
search string
groupID int64
wantCount int
validate func(accounts []service.Account)
name string
setup func(client *dbent.Client)
platform string
accType string
status string
search string
groupID int64
privacyMode string
wantCount int
validate func(accounts []service.Account)
}{
{
name: "filter_by_platform",
@ -281,6 +282,32 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Empty(accounts[0].GroupIDs)
},
},
{
name: "filter_by_privacy_mode",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-ok", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-fail", Extra: map[string]any{"privacy_mode": service.PrivacyModeFailed}})
},
privacyMode: service.PrivacyModeTrainingOff,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("privacy-ok", accounts[0].Name)
},
},
{
name: "filter_by_privacy_mode_unset",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-unset", Extra: nil})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-empty", Extra: map[string]any{"privacy_mode": ""}})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-set", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
},
privacyMode: service.AccountPrivacyModeUnsetFilter,
wantCount: 2,
validate: func(accounts []service.Account) {
names := []string{accounts[0].Name, accounts[1].Name}
s.ElementsMatch([]string{"privacy-unset", "privacy-empty"}, names)
},
},
}
for _, tt := range tests {
@ -293,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
@ -360,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)

View File

@ -29,6 +29,11 @@ INSERT INTO ops_error_logs (
model,
request_path,
stream,
inbound_endpoint,
upstream_endpoint,
requested_model,
upstream_model,
request_type,
user_agent,
error_phase,
error_type,
@ -57,7 +62,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43
)`
func NewOpsRepository(db *sql.DB) service.OpsRepository {
@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
opsNullString(input.Model),
opsNullString(input.RequestPath),
input.Stream,
opsNullString(input.InboundEndpoint),
opsNullString(input.UpstreamEndpoint),
opsNullString(input.RequestedModel),
opsNullString(input.UpstreamModel),
opsNullInt16(input.RequestType),
opsNullString(input.UserAgent),
input.ErrorPhase,
input.ErrorType,
@ -231,7 +241,12 @@ SELECT
COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream
e.stream,
COALESCE(e.inbound_endpoint, ''),
COALESCE(e.upstream_endpoint, ''),
COALESCE(e.requested_model, ''),
COALESCE(e.upstream_model, ''),
e.request_type
FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id
@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
var resolvedBy sql.NullInt64
var resolvedByName string
var resolvedRetryID sql.NullInt64
var requestType sql.NullInt64
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
&clientIP,
&item.RequestPath,
&item.Stream,
&item.InboundEndpoint,
&item.UpstreamEndpoint,
&item.RequestedModel,
&item.UpstreamModel,
&requestType,
); err != nil {
return nil, err
}
@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
item.GroupID = &v
}
item.GroupName = groupName
if requestType.Valid {
v := int16(requestType.Int64)
item.RequestType = &v
}
out = append(out, &item)
}
if err := rows.Err(); err != nil {
@ -393,6 +418,11 @@ SELECT
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''),
e.stream,
COALESCE(e.inbound_endpoint, ''),
COALESCE(e.upstream_endpoint, ''),
COALESCE(e.requested_model, ''),
COALESCE(e.upstream_model, ''),
e.request_type,
COALESCE(e.user_agent, ''),
e.auth_latency_ms,
e.routing_latency_ms,
@ -427,6 +457,7 @@ LIMIT 1`
var responseLatency sql.NullInt64
var ttft sql.NullInt64
var requestBodyBytes sql.NullInt64
var requestType sql.NullInt64
err := r.db.QueryRowContext(ctx, q, id).Scan(
&out.ID,
@ -464,6 +495,11 @@ LIMIT 1`
&clientIP,
&out.RequestPath,
&out.Stream,
&out.InboundEndpoint,
&out.UpstreamEndpoint,
&out.RequestedModel,
&out.UpstreamModel,
&requestType,
&out.UserAgent,
&authLatency,
&routingLatency,
@ -540,6 +576,10 @@ LIMIT 1`
v := int(requestBodyBytes.Int64)
out.RequestBodyBytes = &v
}
if requestType.Valid {
v := int16(requestType.Int64)
out.RequestType = &v
}
// Normalize request_body to empty string when stored as JSON null.
out.RequestBody = strings.TrimSpace(out.RequestBody)
@ -1479,3 +1519,10 @@ func opsNullInt(v any) any {
return sql.NullInt64{}
}
}
func opsNullInt16(v *int16) any {
if v == nil {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: int64(*v), Valid: true}
}

View File

@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) {
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": []
"custom_menu_items": [],
"custom_endpoints": []
}
}`,
},
@ -989,7 +990,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@ -69,12 +69,30 @@ func RegisterGatewayRoutes(
})
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
// OpenAI Responses API: auto-route based on group platform
gateway.POST("/responses", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
})
gateway.POST("/responses/*subpath", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
})
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
// OpenAI Chat Completions API: auto-route based on group platform
gateway.POST("/chat/completions", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.ChatCompletions(c)
return
}
h.Gateway.ChatCompletions(c)
})
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
@ -92,12 +110,25 @@ func RegisterGatewayRoutes(
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
// OpenAI Responses API不带v1前缀的别名— auto-route based on group platform
responsesHandler := func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
}
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API不带v1前缀的别名
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// OpenAI Chat Completions API不带v1前缀的别名— auto-route based on group platform
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.ChatCompletions(c)
return
}
h.Gateway.ChatCompletions(c)
})
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)

View File

@ -0,0 +1,30 @@
package service
import "context"
type accountCredentialsUpdater interface {
UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error
}
func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error {
if repo == nil || account == nil {
return nil
}
account.Credentials = cloneCredentials(credentials)
if updater, ok := any(repo).(accountCredentialsUpdater); ok {
return updater.UpdateCredentials(ctx, account.ID, account.Credentials)
}
return repo.Update(ctx, account)
}
func cloneCredentials(in map[string]any) map[string]any {
if in == nil {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}

View File

@ -15,6 +15,7 @@ var (
)
const AccountListGroupUngrouped int64 = -1
const AccountPrivacyModeUnsetFilter = "__unset__"
type AccountRepository interface {
Create(ctx context.Context, account *Account) error
@ -37,7 +38,7 @@ type AccountRepository interface {
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)

View File

@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call")
}
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}

View File

@ -54,7 +54,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@ -1451,9 +1451,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil {
return nil, 0, err
}

View File

@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct {
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersPrivacy string
listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
s.listWithFiltersPrivacy = privacyMode
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
})
}
func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
t.Run("privacy_mode 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 2, Name: "acc2"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
require.Equal(t, PrivacyModeCFBlocked, repo.listWithFiltersPrivacy)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{

View File

@ -643,6 +643,7 @@ urlFallbackLoop:
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error",
Message: safeErr,
})
@ -720,6 +721,7 @@ urlFallbackLoop:
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
@ -754,6 +756,7 @@ urlFallbackLoop:
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),

View File

@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil {
slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID,
"error", updateErr,

View File

@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after creation
if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
}
}
item.Action = "created"
@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update
if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
}
}
@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
// 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
}
item.Action = "created"
result.Created++
@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
}
item.Action = "updated"
@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
}
item.Action = "created"
result.Created++
@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
}
item.Action = "updated"

View File

@ -119,6 +119,7 @@ const (
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表JSON 数组)
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

View File

@ -12,6 +12,7 @@ import (
"net/smtp"
"net/url"
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySMTPHost]
host := strings.TrimSpace(settings[SettingKeySMTPHost])
if host == "" {
return nil, ErrEmailNotConfigured
}
@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return &SMTPConfig{
Host: host,
Port: port,
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
Username: strings.TrimSpace(settings[SettingKeySMTPUsername]),
Password: strings.TrimSpace(settings[SettingKeySMTPPassword]),
From: strings.TrimSpace(settings[SettingKeySMTPFrom]),
FromName: strings.TrimSpace(settings[SettingKeySMTPFromName]),
UseTLS: useTLS,
}, nil
}

View File

@ -0,0 +1,485 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts an OpenAI Chat Completions API request body,
// converts it to Anthropic Messages format (chained via Responses format),
// forwards to the Anthropic upstream, and converts the response back to Chat
// Completions format. This enables Chat Completions clients to access Anthropic
// models through Anthropic platform groups.
func (s *GatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
parsed *ParsedRequest,
) (*ForwardResult, error) {
startTime := time.Now()
// 1. Parse Chat Completions request
var ccReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &ccReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := ccReq.Model
clientStream := ccReq.Stream
includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
// 2. Convert CC → Responses → Anthropic (chained conversion)
responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
if err != nil {
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
}
// 3. Force upstream streaming
anthropicReq.Stream = true
reqStream := true
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
}
}
anthropicReq.Model = mappedModel
logger.L().Debug("gateway forward_as_chat_completions: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("client_stream", clientStream),
)
// 5. Marshal Anthropic request body
anthropicBody, err := json.Marshal(anthropicReq)
if err != nil {
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts
isClaudeCode := false // CC API is never Claude Code
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
}
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 12. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
}
}
writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// 13. Extract reasoning effort from CC request body
reasoningEffort := extractCCReasoningEffortFromBody(body)
// 14. Handle normal response
// Read Anthropic SSE → convert to Responses events → convert to CC format
var result *ForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage)
} else {
result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
}
return result, handleErr
}
// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions
// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort)
// formats used by OpenAI-compatible clients.
func extractCCReasoningEffortFromBody(body []byte) *string {
raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if raw == "" {
raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
}
if raw == "" {
return nil
}
normalized := normalizeOpenAIReasoningEffort(raw)
if normalized == "" {
return nil
}
return &normalized
}
// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full
// response, then converts Anthropic → Responses → Chat Completions.
func (s *GatewayService) handleCCBufferedFromAnthropic(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResp *apicompat.AnthropicResponse
var usage ClaudeUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
// message_start carries the initial response structure and cache usage
if event.Type == "message_start" && event.Message != nil {
finalResp = event.Message
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// message_delta carries final usage and stop_reason
if event.Type == "message_delta" {
if event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
finalResp.StopReason = event.Delta.StopReason
}
}
if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
finalResp.Content = append(finalResp.Content, *event.ContentBlock)
}
if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
idx := *event.Index
if idx < len(finalResp.Content) {
switch event.Delta.Type {
case "text_delta":
finalResp.Content[idx].Text += event.Delta.Text
case "thinking_delta":
finalResp.Content[idx].Thinking += event.Delta.Thinking
case "input_json_delta":
finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_cc buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResp == nil {
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
return nil, fmt.Errorf("upstream stream ended without response")
}
// Update usage from accumulated delta
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
finalResp.Usage = apicompat.AnthropicUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
}
}
// Chain: Anthropic → Responses → Chat Completions
responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, ccResp)
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each
// to Responses events, then to Chat Completions chunks, and writes them.
func (s *GatewayService) handleCCStreamingFromAnthropic(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
includeUsage bool,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
// Use Anthropic→Responses state machine, then convert Responses→CC
anthState := apicompat.NewAnthropicEventToResponsesState()
anthState.Model = originalModel
ccState := apicompat.NewResponsesEventToChatState()
ccState.Model = originalModel
ccState.IncludeUsage = includeUsage
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *ForwardResult {
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
return false
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
return true // client disconnected
}
return false
}
processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// Extract usage from message_delta
if event.Type == "message_delta" && event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
// Also capture usage from message_start (carries cache fields)
if event.Type == "message_start" && event.Message != nil {
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// Chain: Anthropic event → Responses events → CC chunks
responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState)
for _, resEvt := range responsesEvents {
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range ccChunks {
if disconnected := writeChunk(chunk); disconnected {
return true
}
}
}
c.Writer.Flush()
return false
}
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
continue
}
if processAnthropicEvent(&event) {
return resultWithUsage(), nil
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_cc stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Finalize both state machines
finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState)
for _, resEvt := range finalResEvents {
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range ccChunks {
writeChunk(chunk) //nolint:errcheck
}
}
finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState)
for _, chunk := range finalCCChunks {
writeChunk(chunk) //nolint:errcheck
}
// Write [DONE] marker
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
// writeGatewayCCError writes an error in OpenAI Chat Completions format for
// the Anthropic-upstream CC forwarding path.
func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@ -0,0 +1,109 @@
//go:build unit
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractCCReasoningEffortFromBody(t *testing.T) {
t.Parallel()
t.Run("nested reasoning.effort", func(t *testing.T) {
got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`))
require.NotNil(t, got)
require.Equal(t, "high", *got)
})
t.Run("flat reasoning_effort", func(t *testing.T) {
got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`))
require.NotNil(t, got)
require.Equal(t, "xhigh", *got)
})
t.Run("missing effort", func(t *testing.T) {
require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`)))
})
}
func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reasoningEffort := "high"
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens)
require.Equal(t, 7, result.Usage.OutputTokens)
require.Equal(t, 9, result.Usage.CacheReadInputTokens)
require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "high", *result.ReasoningEffort)
}
func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reasoningEffort := "medium"
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_cc_stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
``,
`event: message_stop`,
`data: {"type":"message_stop"}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 20, result.Usage.InputTokens)
require.Equal(t, 8, result.Usage.OutputTokens)
require.Equal(t, 11, result.Usage.CacheReadInputTokens)
require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "medium", *result.ReasoningEffort)
require.Contains(t, rec.Body.String(), `[DONE]`)
}

View File

@ -0,0 +1,518 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardAsResponses accepts an OpenAI Responses API request body, converts it
// to Anthropic Messages format, forwards to the Anthropic upstream, and converts
// the response back to Responses format. This enables OpenAI Responses API
// clients to access Anthropic models through Anthropic platform groups.
//
// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic
// but in reverse direction: Responses → Anthropic upstream → Responses.
func (s *GatewayService) ForwardAsResponses(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
parsed *ParsedRequest,
) (*ForwardResult, error) {
startTime := time.Now()
// 1. Parse Responses request
var responsesReq apicompat.ResponsesRequest
if err := json.Unmarshal(body, &responsesReq); err != nil {
return nil, fmt.Errorf("parse responses request: %w", err)
}
originalModel := responsesReq.Model
clientStream := responsesReq.Stream
// 2. Convert Responses → Anthropic
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq)
if err != nil {
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
}
// 3. Force upstream streaming (Anthropic works best with streaming)
anthropicReq.Stream = true
reqStream := true
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
}
}
anthropicReq.Model = mappedModel
logger.L().Debug("gateway forward_as_responses: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("client_stream", clientStream),
)
// 5. Marshal Anthropic request body
anthropicBody, err := json.Marshal(anthropicReq)
if err != nil {
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
isClaudeCode := false // Responses API is never Claude Code
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
}
}
// 7. Enforce cache_control block limit
anthropicBody = enforceCacheControlLimit(anthropicBody)
// 8. Get access token
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 9. Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 10. Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// 11. Send request
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 12. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
}
}
// Non-failover error: return Responses-formatted error to client
writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// 13. Handle normal response (convert Anthropic → Responses)
var result *ForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
} else {
result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
}
return result, handleErr
}
// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort
// and normalizes it for usage logging.
func ExtractResponsesReasoningEffortFromBody(body []byte) *string {
raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if raw == "" {
return nil
}
normalized := normalizeOpenAIReasoningEffort(raw)
if normalized == "" {
return nil
}
return &normalized
}
func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) {
if dst == nil {
return
}
if src.InputTokens > 0 {
dst.InputTokens = src.InputTokens
}
if src.OutputTokens > 0 {
dst.OutputTokens = src.OutputTokens
}
if src.CacheReadInputTokens > 0 {
dst.CacheReadInputTokens = src.CacheReadInputTokens
}
if src.CacheCreationInputTokens > 0 {
dst.CacheCreationInputTokens = src.CacheCreationInputTokens
}
}
// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from
// the upstream streaming response, assembles them into a complete Anthropic
// response, converts to Responses API JSON format, and writes it to the client.
func (s *GatewayService) handleResponsesBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// Accumulate the final Anthropic response from streaming events
var finalResp *apicompat.AnthropicResponse
var usage ClaudeUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
eventType := strings.TrimPrefix(line, "event: ")
// Read the data line
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("forward_as_responses buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
zap.String("event_type", eventType),
)
continue
}
// message_start carries the initial response structure
if event.Type == "message_start" && event.Message != nil {
finalResp = event.Message
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// message_delta carries final usage and stop_reason
if event.Type == "message_delta" {
if event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
finalResp.StopReason = event.Delta.StopReason
}
}
// Accumulate content blocks
if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
finalResp.Content = append(finalResp.Content, *event.ContentBlock)
}
if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
idx := *event.Index
if idx < len(finalResp.Content) {
switch event.Delta.Type {
case "text_delta":
finalResp.Content[idx].Text += event.Delta.Text
case "thinking_delta":
finalResp.Content[idx].Thinking += event.Delta.Thinking
case "input_json_delta":
finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_responses buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResp == nil {
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
return nil, fmt.Errorf("upstream stream ended without response")
}
// Update usage from accumulated delta
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
finalResp.Usage = apicompat.AnthropicUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
}
}
// Convert to Responses format
responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
responsesResp.Model = originalModel // Use original model name
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, responsesResp)
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleResponsesStreamingResponse reads Anthropic SSE events from upstream,
// converts each to Responses SSE events, and writes them to the client.
func (s *GatewayService) handleResponsesStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewAnthropicEventToResponsesState()
state.Model = originalModel
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *ForwardResult {
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processEvent handles a single parsed Anthropic SSE event.
processEvent := func(event *apicompat.AnthropicStreamEvent) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// Extract usage from message_delta
if event.Type == "message_delta" && event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
// Also capture usage from message_start
if event.Type == "message_start" && event.Message != nil {
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// Convert to Responses events
events := apicompat.AnthropicEventToResponsesEvents(event, state)
for _, evt := range events {
sse, err := apicompat.ResponsesEventToSSE(evt)
if err != nil {
logger.L().Warn("forward_as_responses stream: failed to marshal event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
return true // client disconnected
}
}
if len(events) > 0 {
c.Writer.Flush()
}
return false
}
finalizeStream := func() (*ForwardResult, error) {
if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesEventToSSE(evt)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// Read Anthropic SSE events
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
eventType := strings.TrimPrefix(line, "event: ")
// Read data line
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("forward_as_responses stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
zap.String("event_type", eventType),
)
continue
}
if processEvent(&event) {
return resultWithUsage(), nil
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_responses stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
return finalizeStream()
}
// appendRawJSON appends a JSON fragment string to existing raw JSON.
func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage {
if len(existing) == 0 {
return json.RawMessage(fragment)
}
return json.RawMessage(string(existing) + fragment)
}
// writeResponsesError writes an error response in OpenAI Responses API format.
func writeResponsesError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"code": code,
"message": message,
},
})
}
// mapUpstreamStatusCode maps upstream HTTP status codes to appropriate client-facing codes.
func mapUpstreamStatusCode(code int) int {
if code >= 500 {
return http.StatusBadGateway
}
return code
}

View File

@ -0,0 +1,94 @@
//go:build unit
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractResponsesReasoningEffortFromBody(t *testing.T) {
t.Parallel()
got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`))
require.NotNil(t, got)
require.Equal(t, "high", *got)
require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`)))
}
func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_buffered"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens)
require.Equal(t, 7, result.Usage.OutputTokens)
require.Equal(t, 9, result.Usage.CacheReadInputTokens)
require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
require.Contains(t, rec.Body.String(), `"cached_tokens":9`)
}
func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
resp := &http.Response{
Header: http.Header{"x-request-id": []string{"rid_stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
``,
`event: content_block_start`,
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
``,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
``,
`event: message_stop`,
`data: {"type":"message_stop"}`,
``,
}, "\n"))),
}
svc := &GatewayService{}
result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 20, result.Usage.InputTokens)
require.Equal(t, 8, result.Usage.OutputTokens)
require.Equal(t, 11, result.Usage.CacheReadInputTokens)
require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
require.Contains(t, rec.Body.String(), `response.completed`)
}

View File

@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {

View File

@ -5,6 +5,8 @@ import (
"encoding/json"
"fmt"
"math"
"regexp"
"sort"
"strings"
"unsafe"
@ -34,6 +36,9 @@ var (
patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`)
sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`)
sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`)
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
@ -75,6 +80,49 @@ type ParsedRequest struct {
OnUpstreamAccepted func()
}
// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing.
// It preserves the set of product names from Product/Version tokens while
// discarding version-only changes and incidental comments.
func NormalizeSessionUserAgent(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1)
if len(matches) == 0 {
return normalizeSessionUserAgentFallback(raw)
}
products := make([]string, 0, len(matches))
seen := make(map[string]struct{}, len(matches))
for _, match := range matches {
if len(match) < 2 {
continue
}
product := strings.ToLower(strings.TrimSpace(match[1]))
if product == "" {
continue
}
if _, exists := seen[product]; exists {
continue
}
seen[product] = struct{}{}
products = append(products, product)
}
if len(products) == 0 {
return normalizeSessionUserAgentFallback(raw)
}
sort.Strings(products)
return strings.Join(products, "+")
}
func normalizeSessionUserAgentFallback(raw string) string {
normalized := strings.ToLower(strings.Join(strings.Fields(raw), " "))
normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "")
return strings.Join(strings.Fields(normalized), " ")
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。
@ -205,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte {
return []byte(r.Raw)
}
// stripEmptyTextBlocksFromSlice removes empty text blocks from a content slice (including nested tool_result content).
// Returns (cleaned slice, true) if any blocks were removed, or (original, false) if unchanged.
func stripEmptyTextBlocksFromSlice(blocks []any) ([]any, bool) {
var result []any
changed := false
for i, block := range blocks {
blockMap, ok := block.(map[string]any)
if !ok {
if result != nil {
result = append(result, block)
}
continue
}
blockType, _ := blockMap["type"].(string)
// Strip empty text blocks
if blockType == "text" {
if txt, _ := blockMap["text"].(string); txt == "" {
if result == nil {
result = make([]any, 0, len(blocks))
result = append(result, blocks[:i]...)
}
changed = true
continue
}
}
// Recurse into tool_result nested content
if blockType == "tool_result" {
if nestedContent, ok := blockMap["content"].([]any); ok {
if cleaned, nestedChanged := stripEmptyTextBlocksFromSlice(nestedContent); nestedChanged {
if result == nil {
result = make([]any, 0, len(blocks))
result = append(result, blocks[:i]...)
}
changed = true
blockCopy := make(map[string]any, len(blockMap))
for k, v := range blockMap {
blockCopy[k] = v
}
blockCopy["content"] = cleaned
result = append(result, blockCopy)
continue
}
}
}
if result != nil {
result = append(result, block)
}
}
if !changed {
return blocks, false
}
return result, true
}
// StripEmptyTextBlocks removes empty text blocks from the request body (including nested tool_result content).
// This is a lightweight pre-filter for the initial request path to prevent upstream 400 errors.
// Returns the original body unchanged if no empty text blocks are found.
func StripEmptyTextBlocks(body []byte) []byte {
// Fast path: check if body contains empty text patterns
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
bytes.Contains(body, patternEmptyTextSpaced) ||
bytes.Contains(body, patternEmptyTextSp1) ||
bytes.Contains(body, patternEmptyTextSp2)
if !hasEmptyTextBlock {
return body
}
jsonStr := *(*string)(unsafe.Pointer(&body))
msgsRes := gjson.Get(jsonStr, "messages")
if !msgsRes.Exists() || !msgsRes.IsArray() {
return body
}
var messages []any
if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil {
return body
}
modified := false
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
if cleaned, changed := stripEmptyTextBlocksFromSlice(content); changed {
modified = true
msgMap["content"] = cleaned
}
}
if !modified {
return body
}
msgsBytes, err := json.Marshal(messages)
if err != nil {
return body
}
out, err := sjson.SetRawBytes(body, "messages", msgsBytes)
if err != nil {
return body
}
return out
}
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
@ -378,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
}
}
// Recursively strip empty text blocks from tool_result nested content.
if blockType == "tool_result" {
if nestedContent, ok := blockMap["content"].([]any); ok {
if cleaned, changed := stripEmptyTextBlocksFromSlice(nestedContent); changed {
modifiedThisMsg = true
ensureNewContent(bi)
blockCopy := make(map[string]any, len(blockMap))
for k, v := range blockMap {
blockCopy[k] = v
}
blockCopy["content"] = cleaned
newContent = append(newContent, blockCopy)
continue
}
}
}
if newContent != nil {
newContent = append(newContent, block)
}

View File

@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
require.NotEmpty(t, block1["text"])
}
func TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult(t *testing.T) {
// Empty text blocks nested inside tool_result content should also be stripped
input := []byte(`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":"valid result"},
{"type":"text","text":""}
]}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
toolResult := content0[0].(map[string]any)
require.Equal(t, "tool_result", toolResult["type"])
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 1)
require.Equal(t, "valid result", nestedContent[0].(map[string]any)["text"])
}
func TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice(t *testing.T) {
// If all nested content blocks in tool_result are empty text, content becomes empty slice
input := []byte(`{
"messages":[
{"role":"user","content":[
{"type":"tool_result","tool_use_id":"t1","content":[
{"type":"text","text":""}
]},
{"type":"text","text":"hello"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 2)
toolResult := content0[0].(map[string]any)
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 0)
}
func TestStripEmptyTextBlocks(t *testing.T) {
t.Run("strips top-level empty text", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
require.Len(t, content, 1)
require.Equal(t, "hello", content[0].(map[string]any)["text"])
})
t.Run("strips nested empty text in tool_result", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
toolResult := content[0].(map[string]any)
nestedContent := toolResult["content"].([]any)
require.Len(t, nestedContent, 1)
require.Equal(t, "ok", nestedContent[0].(map[string]any)["text"])
})
t.Run("no-op when no empty text", func(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := StripEmptyTextBlocks(input)
require.Equal(t, input, out)
})
t.Run("preserves non-map blocks in content", func(t *testing.T) {
// tool_result content can be a string; non-map blocks should pass through unchanged
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
require.Len(t, content, 1)
toolResult := content[0].(map[string]any)
require.Equal(t, "tool_result", toolResult["type"])
require.Equal(t, "string content", toolResult["content"])
})
t.Run("handles deeply nested tool_result", func(t *testing.T) {
// Recursive: tool_result containing another tool_result with empty text
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`)
out := StripEmptyTextBlocks(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs := req["messages"].([]any)
content := msgs[0].(map[string]any)["content"].([]any)
outer := content[0].(map[string]any)
innerContent := outer["content"].([]any)
inner := innerContent[0].(map[string]any)
deepContent := inner["content"].([]any)
require.Len(t, deepContent, 1)
require.Equal(t, "deep", deepContent[0].(map[string]any)["text"])
})
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged
input := []byte(`{

View File

@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent))
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 调试日志:记录即将转发的账号信息
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body)
@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error",
Message: safeErr,
})
@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "signature_error",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(retryReq.URL.String()),
Kind: "signature_retry_thinking",
Message: extractUpstreamErrorMessage(retryRespBody),
Detail: func() string {
@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(retryReq2.URL.String()),
Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
})
@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "budget_constraint_error",
Message: errMsg,
Detail: func() string {
@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
if c != nil {
c.Set("anthropic_passthrough", true)
}
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
input.Body = StripEmptyTextBlocks(input.Body)
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, input.Body)
@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true,
Kind: "request_error",
Message: safeErr,
@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true,
Kind: "retry",
Message: extractUpstreamErrorMessage(respBody),
@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error",
Message: safeErr,
})
@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream(
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true,
Kind: "request_error",
Message: sanitizeUpstreamErrorMessage(err.Error()),
@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Passthrough: true,
Kind: "http_error",
Message: upstreamMsg,

View File

@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {

View File

@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
// 返回 16 字符的 Base64 编码的 SHA256 前缀
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
// 组合所有标识符
normalizedUserAgent := NormalizeSessionUserAgent(userAgent)
combined := strconv.FormatInt(userID, 10) + ":" +
strconv.FormatInt(apiKeyID, 10) + ":" +
ip + ":" +
userAgent + ":" +
normalizedUserAgent + ":" +
platform + ":" +
model

View File

@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
}
}
func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro")
if hash1 != hash2 {
t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
}
}
func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) {
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro")
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro")
if hash1 != hash2 {
t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string

View File

@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account)
_ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials)
}
}

View File

@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
}
func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0"))
h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1"))
require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash")
}
func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) {
svc := &GatewayService{}
base := func(ua string) *ParsedRequest {
return &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "test"},
},
SessionContext: &SessionContext{
ClientIP: "1.1.1.1",
UserAgent: ua,
APIKeyID: 1,
},
}
}
h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0"))
h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1"))
require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash")
}
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
svc := &GatewayService{}

View File

@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
// 5. 设置版本号 + 更新 DB
if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli()
freshAccount.Credentials = newCredentials
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
slog.Error("oauth_refresh_update_failed",
"account_id", freshAccount.ID,
"error", updateErr,

View File

@ -16,10 +16,11 @@ import (
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type refreshAPIAccountRepo struct {
mockAccountRepoForGemini
account *Account // returned by GetByID
getByIDErr error
updateErr error
updateCalls int
account *Account // returned by GetByID
getByIDErr error
updateErr error
updateCalls int
updateCredentialsCalls int
}
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
return r.updateErr
}
func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
if r.account == nil || r.account.ID != id {
r.account = &Account{ID: id}
}
r.account.Credentials = cloneCredentials(credentials)
return nil
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type refreshAPIExecutorStub struct {
needsRefresh bool
@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
require.Equal(t, "new-token", result.NewCredentials["access_token"])
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, executor.refreshCalls)
}
func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) {
resetAt := time.Now().Add(45 * time.Minute)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "safe-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotNil(t, repo.account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second)
}
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
account := &Account{ID: 2, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "invalid_grant")
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
}
@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
result := MergeCredentials(old, new)
require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
}
// ========== BuildClaudeAccountCredentials tests ==========

View File

@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr

View File

@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
ctx := context.Background()
groupID := int64(10103)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(33002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10104)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{stalePrimary, staleSecondary},
accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(34002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background()
groupID := int64(9)

View File

@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
if fresh == nil {
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return fresh
}
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
return account
}
latest, err := s.accountRepo.GetByID(ctx, account.ID)
if err != nil || latest == nil {
return nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())
if !latest.IsSchedulable() || !latest.IsOpenAI() {
return nil
}
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
return nil
}
return latest
}
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
var (
account *Account
@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
// Passthrough mode preserves the raw upstream error response, but runtime
// account state still needs to be updated so sticky routing can stop
// reusing a freshly rate-limited account.
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,

View File

@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
require.True(t, arr[len(arr)-1].Passthrough)
}
func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-rate-limit"},
},
Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))),
}
upstream := &httpUpstreamRecorder{resp: resp}
repo := &openAIWSRateLimitSignalRepo{}
rateSvc := &RateLimitService{accountRepo: repo}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
rateLimitService: rateSvc,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": true},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
_, err := svc.Forward(context.Background(), c, account, originalBody)
require.Error(t, err)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
require.Contains(t, rec.Body.String(), "usage_limit_reached")
require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
}
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -29,9 +29,10 @@ type soraSessionChunk struct {
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
proxyRepo ProxyRepository
oauthClient OpenAIOAuthClient
sessionStore *openai.SessionStore
proxyRepo ProxyRepository
oauthClient OpenAIOAuthClient
privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-apiImpersonateChrome
}
// NewOpenAIOAuthService creates a new OpenAI OAuth service
@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli
}
}
// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂,
// 用于调用 chatgpt.com/backend-api 获取账号信息plan_type 等)。
func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) {
s.privacyClientFactory = factory
}
// OpenAIAuthURLResult contains the authorization URL and session info
type OpenAIAuthURLResult struct {
AuthURL string `json:"auth_url"`
@ -131,6 +138,7 @@ type OpenAITokenInfo struct {
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
OrganizationID string `json:"organization_id,omitempty"`
PlanType string `json:"plan_type,omitempty"`
PrivacyMode string `json:"privacy_mode,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo.PlanType = userInfo.PlanType
}
// id_token 中缺少 plan_type 时(如 Mobile RT尝试通过 ChatGPT backend-api 补全
if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
// 从 access_token JWT 中提取 orgIDpoid用于匹配正确的账号
orgID := tokenInfo.OrganizationID
if orgID == "" {
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
orgID = atClaims.OpenAIAuth.POID
}
}
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
if tokenInfo.PlanType == "" && info.PlanType != "" {
tokenInfo.PlanType = info.PlanType
}
if tokenInfo.Email == "" && info.Email != "" {
tokenInfo.Email = info.Email
}
}
}
// 尝试设置隐私关闭训练数据共享best-effort
if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
return tokenInfo, nil
}

View File

@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
return PrivacyModeTrainingOff
}
// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息
type ChatGPTAccountInfo struct {
PlanType string
Email string
}
const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.).
// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT).
// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team).
// Returns nil on any failure (best-effort, non-blocking).
func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, orgID string) *ChatGPTAccountInfo {
if accessToken == "" || clientFactory == nil {
return nil
}
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
client, err := clientFactory(proxyURL)
if err != nil {
slog.Debug("chatgpt_account_check_client_error", "error", err.Error())
return nil
}
var result map[string]any
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Origin", "https://chatgpt.com").
SetHeader("Referer", "https://chatgpt.com/").
SetHeader("Accept", "application/json").
SetSuccessResult(&result).
Get(chatGPTAccountsCheckURL)
if err != nil {
slog.Debug("chatgpt_account_check_request_error", "error", err.Error())
return nil
}
if !resp.IsSuccessState() {
slog.Debug("chatgpt_account_check_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200))
return nil
}
info := &ChatGPTAccountInfo{}
accounts, ok := result["accounts"].(map[string]any)
if !ok {
slog.Debug("chatgpt_account_check_no_accounts", "body", truncate(resp.String(), 300))
return nil
}
// 优先匹配 orgID 对应的账号access_token JWT 中的 poid
if orgID != "" {
if matched := extractPlanFromAccount(accounts, orgID); matched != "" {
info.PlanType = matched
}
}
// 未匹配到时,遍历所有账号:优先 is_default次选非 free
if info.PlanType == "" {
var defaultPlan, paidPlan, anyPlan string
for _, acctRaw := range accounts {
acct, ok := acctRaw.(map[string]any)
if !ok {
continue
}
planType := extractPlanType(acct)
if planType == "" {
continue
}
if anyPlan == "" {
anyPlan = planType
}
if account, ok := acct["account"].(map[string]any); ok {
if isDefault, _ := account["is_default"].(bool); isDefault {
defaultPlan = planType
}
}
if !strings.EqualFold(planType, "free") && paidPlan == "" {
paidPlan = planType
}
}
// 优先级default > 非 free > 任意
switch {
case defaultPlan != "":
info.PlanType = defaultPlan
case paidPlan != "":
info.PlanType = paidPlan
default:
info.PlanType = anyPlan
}
}
if info.PlanType == "" {
slog.Debug("chatgpt_account_check_no_plan_type", "body", truncate(resp.String(), 300))
return nil
}
slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID)
return info
}
// extractPlanFromAccount 从 accounts map 中按 keyaccount_id精确匹配并提取 plan_type
func extractPlanFromAccount(accounts map[string]any, accountKey string) string {
acctRaw, ok := accounts[accountKey]
if !ok {
return ""
}
acct, ok := acctRaw.(map[string]any)
if !ok {
return ""
}
return extractPlanType(acct)
}
// extractPlanType 从单个 account 对象中提取 plan_type
func extractPlanType(acct map[string]any) string {
if account, ok := acct["account"].(map[string]any); ok {
if planType, ok := account["plan_type"].(string); ok && planType != "" {
return planType
}
}
if entitlement, ok := acct["entitlement"].(map[string]any); ok {
if subPlan, ok := entitlement["subscription_plan"].(string); ok && subPlan != "" {
return subPlan
}
}
return ""
}
func truncate(s string, n int) string {
if len(s) <= n {
return s

View File

@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.Zero(t, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) {
ctx := context.Background()
groupID := int64(24)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleAccount := &Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
dbAccount := Account{
ID: 13,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
RateLimitResetAt: &rateLimitedUntil,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
store := NewOpenAIWSStateStore(cache)
cfg := newOpenAIWSV2TestConfig()
snapshotCache := &openAISnapshotCacheStub{
accountsByID: map[int64]*Account{dbAccount.ID: staleAccount},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
openaiWSStateStore: store,
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
}
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
require.NoError(t, getErr)
require.Zero(t, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
ctx := context.Background()
groupID := int64(23)

View File

@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {

View File

@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re
return nil
}
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
_ = platform
_ = accountType
_ = status
_ = search
_ = groupID
_ = privacyMode
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
}
@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Len(t, accounts, 1)

View File

@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page,
PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "", 0)
}, platformFilter, "", "", "", 0, "")
if err != nil {
return nil, err
}

View File

@ -62,6 +62,12 @@ type OpsErrorLog struct {
ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"`
Stream bool `json:"stream"`
InboundEndpoint string `json:"inbound_endpoint"`
UpstreamEndpoint string `json:"upstream_endpoint"`
RequestedModel string `json:"requested_model"`
UpstreamModel string `json:"upstream_model"`
RequestType *int16 `json:"request_type"`
}
type OpsErrorLogDetail struct {

View File

@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct {
Model string
RequestPath string
Stream bool
// InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint string
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint string
// RequestedModel is the client-requested model name before mapping.
RequestedModel string
// UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping.
UpstreamModel string
// RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2.
// Matches service.RequestType enum semantics from usage_log.go.
RequestType *int16
UserAgent string
ErrorPhase string

View File

@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct {
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped).
// Helps debug 404/routing errors by showing which endpoint was targeted.
UpstreamURL string `json:"upstream_url,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind)
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail)
if ev.Message != "" {
@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
}
return out, nil
}
// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment
// to avoid leaking sensitive query parameters (e.g. OAuth tokens).
func safeUpstreamURL(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
if idx := strings.IndexByte(rawURL, '?'); idx >= 0 {
rawURL = rawURL[:idx]
}
if idx := strings.IndexByte(rawURL, '#'); idx >= 0 {
rawURL = rawURL[:idx]
}
return rawURL
}

View File

@ -8,6 +8,27 @@ import (
"github.com/stretchr/testify/require"
)
func TestSafeUpstreamURL(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"},
{"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"},
{"strips both", "https://host/path?token=secret#x", "https://host/path"},
{"no query or fragment", "https://host/path", "https://host/path"},
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"query before fragment", "https://h/p?a=1#f", "https://h/p"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, safeUpstreamURL(tt.input))
})
}
}
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
account.Credentials = make(map[string]any)
}
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil {
if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)

View File

@ -15,9 +15,11 @@ import (
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
setErrorCalls int
tempCalls int
lastErrorMsg string
setErrorCalls int
tempCalls int
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
return nil
}
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCredentialsCalls++
r.lastCredentials = cloneCredentials(credentials)
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1)
}
@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}

View File

@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic(
func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) {
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) {

View File

@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
}
@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil
@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"`
@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version,
@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
if raw == "" {
return json.RawMessage("[]")
}
if json.Valid([]byte(raw)) {
return json.RawMessage(raw)
}
return json.RawMessage("[]")
}
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}

View File

@ -43,6 +43,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
DefaultConcurrency int
DefaultBalance float64
@ -104,6 +105,7 @@ type PublicSettings struct {
PurchaseSubscriptionURL string
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
BackendModeEnabled bool

View File

@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
}
if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}

View File

@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
newCredentials, err = refresher.Refresh(ctx, account)
if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr)
}
}

View File

@ -14,19 +14,40 @@ import (
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
clearTempCalls int
lastAccount *Account
updateErr error
updateCalls int
fullUpdateCalls int
updateCredentialsCalls int
setErrorCalls int
clearTempCalls int
lastAccount *Account
updateErr error
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.fullUpdateCalls++
r.lastAccount = account
return r.updateErr
}
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCalls++
r.updateCredentialsCalls++
if r.updateErr != nil {
return r.updateErr
}
cloned := cloneCredentials(credentials)
if r.accountsByID != nil {
if acc, ok := r.accountsByID[id]; ok && acc != nil {
acc.Credentials = cloned
r.lastAccount = acc
return nil
}
}
r.lastAccount = &Account{ID: id, Credentials: cloned}
return nil
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
}
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
resetAt := time.Now().Add(30 * time.Minute)
account := &Account{
ID: 17,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
RateLimitResetAt: &resetAt,
Credentials: map[string]any{
"access_token": "old-token",
},
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Equal(t, 0, repo.fullUpdateCalls)
require.NotNil(t, account.RateLimitResetAt)
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
}

View File

@ -0,0 +1,28 @@
-- Ops error logs: add endpoint, model mapping, and request_type fields
-- to match usage_logs observability coverage.
--
-- All columns are nullable with no default to preserve backward compatibility
-- with existing rows.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256),
ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256);
-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100),
ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);
-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2)
ALTER TABLE ops_error_logs
ADD COLUMN IF NOT EXISTS request_type SMALLINT;
COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.';
COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.';
COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).';
COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.';
COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. Matches usage_logs.request_type semantics.';

View File

@ -36,6 +36,7 @@ export async function list(
status?: string
group?: string
search?: string
privacy_mode?: string
lite?: string
},
options?: {
@ -68,6 +69,7 @@ export async function listWithEtag(
status?: string
group?: string
search?: string
privacy_mode?: string
lite?: string
},
options?: {
@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
export async function refreshOpenAIToken(
refreshToken: string,
proxyId?: number | null,
endpoint: string = '/admin/openai/refresh-token'
endpoint: string = '/admin/openai/refresh-token',
clientId?: string
): Promise<Record<string, unknown>> {
const payload: { refresh_token: string; proxy_id?: number } = {
const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = {
refresh_token: refreshToken
}
if (proxyId) {
payload.proxy_id = proxyId
}
if (clientId) {
payload.client_id = clientId
}
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data
}

View File

@ -969,6 +969,13 @@ export interface OpsErrorLog {
client_ip?: string | null
request_path?: string
stream?: boolean
// Error observability context (endpoint + model mapping)
inbound_endpoint?: string
upstream_endpoint?: string
requested_model?: string
upstream_model?: string
request_type?: number | null
}
export interface OpsErrorDetail extends OpsErrorLog {

View File

@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
import type { CustomMenuItem } from '@/types'
import type { CustomMenuItem, CustomEndpoint } from '@/types'
export interface DefaultSubscriptionSetting {
group_id: number
@ -43,6 +43,7 @@ export interface SystemSettings {
sora_client_enabled: boolean
backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
// SMTP settings
smtp_host: string
smtp_port: number
@ -112,6 +113,7 @@ export interface UpdateSettingsRequest {
sora_client_enabled?: boolean
backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[]
custom_endpoints?: CustomEndpoint[]
smtp_host?: string
smtp_port?: number
smtp_username?: string

View File

@ -661,6 +661,43 @@
</div>
</div>
<!-- OpenAI OAuth WS mode -->
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-ws-mode-label"
class="input-label mb-0"
for="bulk-edit-openai-ws-mode-enabled"
>
{{ t('admin.accounts.openai.wsMode') }}
</label>
<input
v-model="enableOpenAIWSMode"
id="bulk-edit-openai-ws-mode-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-ws-mode"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-ws-mode"
:class="!enableOpenAIWSMode && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.wsModeDesc') }}
</p>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t(openAIWSModeConcurrencyHintKey) }}
</p>
<Select
v-model="openaiOAuthResponsesWebSocketV2Mode"
data-testid="bulk-edit-openai-ws-mode-select"
:options="openAIWSModeOptions"
aria-labelledby="bulk-edit-openai-ws-mode-label"
/>
</div>
</div>
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
@ -883,6 +920,13 @@ import {
buildModelMappingObject as buildModelMappingPayload,
getPresetMappingsByPlatform
} from '@/composables/useModelWhitelist'
import {
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
resolveOpenAIWSModeConcurrencyHintKey
} from '@/utils/openaiWsMode'
import type { OpenAIWSMode } from '@/utils/openaiWsMode'
interface Props {
show: boolean
accountIds: number[]
@ -913,6 +957,15 @@ const allOpenAIPassthroughCapable = computed(() => {
)
})
const allOpenAIOAuth = computed(() => {
return (
props.selectedPlatforms.length === 1 &&
props.selectedPlatforms[0] === 'openai' &&
props.selectedTypes.length > 0 &&
props.selectedTypes.every(t => t === 'oauth')
)
})
// Anthropic OAuth/SetupTokenRPM
const allAnthropicOAuthOrSetupToken = computed(() => {
return (
@ -957,6 +1010,7 @@ const enableRateMultiplier = ref(false)
const enableStatus = ref(false)
const enableGroups = ref(false)
const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
const enableRpmLimit = ref(false)
// State - field values
@ -979,6 +1033,7 @@ const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref<number | null>(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
@ -1005,10 +1060,19 @@ const statusOptions = computed(() => [
{ value: 'active', label: t('common.active') },
{ value: 'inactive', label: t('common.inactive') }
])
const isOpenAIModelRestrictionDisabled = computed(() =>
allOpenAIPassthroughCapable.value &&
enableOpenAIPassthrough.value &&
openaiPassthroughEnabled.value
const isOpenAIModelRestrictionDisabled = computed(
() =>
allOpenAIPassthroughCapable.value &&
enableOpenAIPassthrough.value &&
openaiPassthroughEnabled.value
)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
// Model mapping helpers
@ -1180,6 +1244,14 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.credentials = credentials
}
if (enableOpenAIWSMode.value) {
const extra = ensureExtra()
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
openaiOAuthResponsesWebSocketV2Mode.value
)
}
// RPM limit settings ( extra )
if (enableRpmLimit.value) {
const extra = ensureExtra()
@ -1269,6 +1341,7 @@ const handleSubmit = async () => {
enableRateMultiplier.value ||
enableStatus.value ||
enableGroups.value ||
enableOpenAIWSMode.value ||
enableRpmLimit.value ||
userMsgQueueMode.value !== null
@ -1361,6 +1434,7 @@ watch(
enableStatus.value = false
enableGroups.value = false
enableOpenAIPassthrough.value = false
enableOpenAIWSMode.value = false
enableRpmLimit.value = false
// Reset all values
@ -1379,6 +1453,7 @@ watch(
rateMultiplier.value = 1
status.value = 'active'
groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
rpmLimitEnabled.value = false
bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered'

View File

@ -2504,6 +2504,7 @@
:allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
:show-mobile-refresh-token-option="form.platform === 'openai'"
:show-session-token-option="form.platform === 'sora'"
:show-access-token-option="form.platform === 'sora'"
:platform="form.platform"
@ -2511,6 +2512,7 @@
@generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth"
@validate-refresh-token="handleValidateRefreshToken"
@validate-mobile-refresh-token="handleOpenAIValidateMobileRT"
@validate-session-token="handleValidateSessionToken"
@import-access-token="handleImportAccessToken"
/>
@ -4360,11 +4362,14 @@ const handleOpenAIExchange = async (authCode: string) => {
}
// OpenAI RT
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
// OpenAI Mobile RT 使 client_id openai.SoraClientID
const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK'
// OpenAI/Sora RT
const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => {
const oauthClient = activeOpenAIOAuth.value
if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line)
const refreshTokens = refreshTokenInput
.split('\n')
.map((rt) => rt.trim())
@ -4389,7 +4394,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
try {
const tokenInfo = await oauthClient.validateRefreshToken(
refreshTokens[i],
form.proxy_id
form.proxy_id,
clientId
)
if (!tokenInfo) {
failedCount++
@ -4399,6 +4405,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
const credentials = oauthClient.buildCredentials(tokenInfo)
if (clientId) {
credentials.client_id = clientId
}
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const extra = buildOpenAIExtra(oauthExtra)
@ -4410,8 +4419,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
}
// Generate account name with index for batch
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
// Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
const accountName = refreshTokens.length > 1 ? `${baseName} #${i + 1}` : baseName
let openaiAccountId: string | number | undefined
@ -4494,6 +4504,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
}
// RTCodex CLI client_id
const handleOpenAIValidateRT = (rt: string) => handleOpenAIBatchRT(rt)
// Mobile RTSoraClientID
const handleOpenAIValidateMobileRT = (rt: string) => handleOpenAIBatchRT(rt, OPENAI_MOBILE_RT_CLIENT_ID)
// Sora ST
const handleSoraValidateST = async (sessionTokenInput: string) => {
const oauthClient = activeOpenAIOAuth.value

View File

@ -48,6 +48,17 @@
t(getOAuthKey('refreshTokenAuth'))
}}</span>
</label>
<label v-if="showMobileRefreshTokenOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
type="radio"
value="mobile_refresh_token"
class="text-blue-600 focus:ring-blue-500"
/>
<span class="text-sm text-blue-900 dark:text-blue-200">{{
t('admin.accounts.oauth.openai.mobileRefreshTokenAuth', '手动输入 Mobile RT')
}}</span>
</label>
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
@ -73,8 +84,8 @@
</div>
</div>
<!-- Refresh Token Input (OpenAI / Antigravity) -->
<div v-if="inputMethod === 'refresh_token'" class="space-y-4">
<!-- Refresh Token Input (OpenAI / Antigravity / Mobile RT) -->
<div v-if="inputMethod === 'refresh_token' || inputMethod === 'mobile_refresh_token'" class="space-y-4">
<div
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
>
@ -759,6 +770,7 @@ interface Props {
methodLabel?: string
showCookieOption?: boolean // Whether to show cookie auto-auth option
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
showMobileRefreshTokenOption?: boolean // Whether to show mobile refresh token option (OpenAI only)
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
showAccessTokenOption?: boolean // Whether to show access token input option (Sora only)
platform?: AccountPlatform // Platform type for different UI/text
@ -776,6 +788,7 @@ const props = withDefaults(defineProps<Props>(), {
methodLabel: 'Authorization Method',
showCookieOption: true,
showRefreshTokenOption: false,
showMobileRefreshTokenOption: false,
showSessionTokenOption: false,
showAccessTokenOption: false,
platform: 'anthropic',
@ -787,6 +800,7 @@ const emit = defineEmits<{
'exchange-code': [code: string]
'cookie-auth': [sessionKey: string]
'validate-refresh-token': [refreshToken: string]
'validate-mobile-refresh-token': [refreshToken: string]
'validate-session-token': [sessionToken: string]
'import-access-token': [accessToken: string]
'update:inputMethod': [method: AuthInputMethod]
@ -834,7 +848,7 @@ const oauthState = ref('')
const projectId = ref('')
// Computed: show method selection when either cookie or refresh token option is enabled
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
// Clipboard
const { copied, copyToClipboard } = useClipboard()
@ -945,7 +959,11 @@ const handleCookieAuth = () => {
const handleValidateRefreshToken = () => {
if (refreshTokenInput.value.trim()) {
emit('validate-refresh-token', refreshTokenInput.value.trim())
if (inputMethod.value === 'mobile_refresh_token') {
emit('validate-mobile-refresh-token', refreshTokenInput.value.trim())
} else {
emit('validate-refresh-token', refreshTokenInput.value.trim())
}
}
}

View File

@ -149,6 +149,35 @@ describe('BulkEditAccountModal', () => {
})
})
it('OpenAI OAuth 批量编辑应提交 OAuth 专属 WS mode 字段', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['oauth']
})
await wrapper.get('#bulk-edit-openai-ws-mode-enabled').setValue(true)
await wrapper.get('[data-testid="bulk-edit-openai-ws-mode-select"]').setValue('passthrough')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
extra: {
openai_oauth_responses_websockets_v2_mode: 'passthrough',
openai_oauth_responses_websockets_v2_enabled: true
}
})
})
it('OpenAI API Key 批量编辑不显示 WS mode 入口', () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['apikey']
})
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
})
it('OpenAI 账号批量编辑可关闭自动透传', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],

View File

@ -10,6 +10,7 @@
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
<Select :model-value="filters.privacy_mode" class="w-40" :options="privacyOpts" @update:model-value="updatePrivacyMode" @change="$emit('change')" />
<Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" />
</div>
</template>
@ -22,10 +23,18 @@ const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); co
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
const updatePrivacyMode = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, privacy_mode: value }) }
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
const privacyOpts = computed(() => [
{ value: '', label: t('admin.accounts.allPrivacyModes') },
{ value: '__unset__', label: t('admin.accounts.privacyUnset') },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
const gOpts = computed(() => [
{ value: '', label: t('admin.accounts.allGroups') },
{ value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') },

View File

@ -0,0 +1,56 @@
import { describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import AccountTableFilters from '../AccountTableFilters.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
describe('AccountTableFilters', () => {
it('renders privacy mode options and emits privacy_mode updates', async () => {
const wrapper = mount(AccountTableFilters, {
props: {
searchQuery: '',
filters: {
platform: '',
type: '',
status: '',
group: '',
privacy_mode: ''
},
groups: []
},
global: {
stubs: {
SearchInput: {
template: '<div />'
},
Select: {
props: ['modelValue', 'options'],
emits: ['update:modelValue', 'change'],
template: '<div class="select-stub" :data-options="JSON.stringify(options)" />'
}
}
}
})
const selects = wrapper.findAll('.select-stub')
expect(selects).toHaveLength(5)
const privacyOptions = JSON.parse(selects[3].attributes('data-options'))
expect(privacyOptions).toEqual([
{ value: '', label: 'admin.accounts.allPrivacyModes' },
{ value: '__unset__', label: 'admin.accounts.privacyUnset' },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
})
})

View File

@ -0,0 +1,141 @@
<script setup lang="ts">
import { computed, onBeforeUnmount, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { useClipboard } from '@/composables/useClipboard'
import type { CustomEndpoint } from '@/types'
const props = defineProps<{
apiBaseUrl: string
customEndpoints: CustomEndpoint[]
}>()
const { t } = useI18n()
const { copyToClipboard } = useClipboard()
const copiedEndpoint = ref<string | null>(null)
let copiedResetTimer: number | undefined
const allEndpoints = computed(() => {
const items: Array<{ name: string; endpoint: string; description: string; isDefault: boolean }> = []
if (props.apiBaseUrl) {
items.push({
name: t('keys.endpoints.title'),
endpoint: props.apiBaseUrl,
description: '',
isDefault: true,
})
}
for (const ep of props.customEndpoints) {
items.push({ ...ep, isDefault: false })
}
return items
})
async function copy(url: string) {
const success = await copyToClipboard(url, t('keys.endpoints.copied'))
if (!success) return
copiedEndpoint.value = url
if (copiedResetTimer !== undefined) {
window.clearTimeout(copiedResetTimer)
}
copiedResetTimer = window.setTimeout(() => {
if (copiedEndpoint.value === url) {
copiedEndpoint.value = null
}
}, 1800)
}
function tooltipHint(endpoint: string): string {
return copiedEndpoint.value === endpoint
? t('keys.endpoints.copiedHint')
: t('keys.endpoints.clickToCopy')
}
function speedTestUrl(endpoint: string): string {
return `https://www.tcptest.cn/http/${encodeURIComponent(endpoint)}`
}
onBeforeUnmount(() => {
if (copiedResetTimer !== undefined) {
window.clearTimeout(copiedResetTimer)
}
})
</script>
<template>
<div v-if="allEndpoints.length > 0" class="flex flex-wrap gap-2">
<div
v-for="(item, index) in allEndpoints"
:key="index"
class="flex items-center gap-1.5 rounded-lg border border-gray-200 bg-white px-2.5 py-1.5 text-xs transition-colors hover:border-primary-200 dark:border-dark-600 dark:bg-dark-800 dark:hover:border-primary-700"
>
<span class="font-medium text-gray-600 dark:text-gray-300">{{ item.name }}</span>
<span
v-if="item.isDefault"
class="rounded bg-primary-50 px-1 py-px text-[10px] font-medium leading-tight text-primary-600 dark:bg-primary-900/30 dark:text-primary-400"
>{{ t('keys.endpoints.default') }}</span>
<span class="text-gray-300 dark:text-dark-500">|</span>
<div class="group/endpoint relative flex items-center gap-1.5">
<div
class="pointer-events-none absolute bottom-full left-1/2 z-20 mb-2 w-max max-w-[24rem] -translate-x-1/2 translate-y-1 rounded-xl border border-slate-200 bg-white px-3 py-2.5 text-left opacity-0 shadow-[0_14px_36px_-20px_rgba(15,23,42,0.35)] ring-1 ring-slate-200/80 transition-all duration-150 group-hover/endpoint:translate-y-0 group-hover/endpoint:opacity-100 group-focus-within/endpoint:translate-y-0 group-focus-within/endpoint:opacity-100 dark:border-slate-700 dark:bg-slate-900 dark:ring-slate-700/70"
>
<p
v-if="item.description"
class="max-w-[24rem] break-words text-xs leading-5 text-slate-600 dark:text-slate-200"
>
{{ item.description }}
</p>
<p
class="flex items-center gap-1.5 text-[11px] leading-4 text-primary-600 dark:text-primary-300"
:class="item.description ? 'mt-1.5' : ''"
>
<span class="h-1.5 w-1.5 rounded-full bg-primary-500 dark:bg-primary-300"></span>
{{ tooltipHint(item.endpoint) }}
</p>
<div class="absolute left-1/2 top-full h-3 w-3 -translate-x-1/2 -translate-y-1/2 rotate-45 border-b border-r border-slate-200 bg-white dark:border-slate-700 dark:bg-slate-900"></div>
</div>
<code
class="cursor-pointer font-mono text-gray-500 decoration-gray-400 decoration-dashed underline-offset-2 hover:text-primary-600 hover:underline focus:text-primary-600 focus:underline focus:outline-none dark:text-gray-400 dark:decoration-gray-500 dark:hover:text-primary-400 dark:focus:text-primary-400"
role="button"
tabindex="0"
@click="copy(item.endpoint)"
@keydown.enter.prevent="copy(item.endpoint)"
@keydown.space.prevent="copy(item.endpoint)"
>{{ item.endpoint }}</code>
<button
type="button"
class="rounded p-0.5 transition-colors"
:class="copiedEndpoint === item.endpoint
? 'text-emerald-500 dark:text-emerald-400'
: 'text-gray-400 hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400'"
:aria-label="tooltipHint(item.endpoint)"
@click="copy(item.endpoint)"
>
<svg v-if="copiedEndpoint === item.endpoint" class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2.2">
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
</svg>
<svg v-else class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z" />
</svg>
</button>
<a
:href="speedTestUrl(item.endpoint)"
target="_blank"
rel="noopener noreferrer"
class="rounded p-0.5 text-gray-400 transition-colors hover:text-amber-500 dark:text-gray-500 dark:hover:text-amber-400"
:title="t('keys.endpoints.speedTest')"
>
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M13 10V3L4 14h7v7l9-11h-7z" />
</svg>
</a>
</div>
</div>
</div>
</template>

View File

@ -0,0 +1,69 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
const copyToClipboard = vi.fn().mockResolvedValue(true)
const messages: Record<string, string> = {
'keys.endpoints.title': 'API 端点',
'keys.endpoints.default': '默认',
'keys.endpoints.copied': '已复制',
'keys.endpoints.copiedHint': '已复制到剪贴板',
'keys.endpoints.clickToCopy': '点击可复制此端点',
'keys.endpoints.speedTest': '测速',
}
vi.mock('vue-i18n', () => ({
useI18n: () => ({
t: (key: string) => messages[key] ?? key,
}),
}))
vi.mock('@/composables/useClipboard', () => ({
useClipboard: () => ({
copyToClipboard,
}),
}))
import EndpointPopover from '../EndpointPopover.vue'
describe('EndpointPopover', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('将说明提示渲染到 URL 上方而不是旧的 title 图标上', () => {
const wrapper = mount(EndpointPopover, {
props: {
apiBaseUrl: 'https://default.example.com/v1',
customEndpoints: [
{
name: '备用线路',
endpoint: 'https://backup.example.com/v1',
description: '自定义说明',
},
],
},
})
expect(wrapper.text()).toContain('自定义说明')
expect(wrapper.text()).toContain('点击可复制此端点')
expect(wrapper.find('[role="button"]').attributes('title')).toBeUndefined()
expect(wrapper.find('[title="自定义说明"]').exists()).toBe(false)
})
it('点击 URL 后会复制并切换为已复制提示', async () => {
const wrapper = mount(EndpointPopover, {
props: {
apiBaseUrl: 'https://default.example.com/v1',
customEndpoints: [],
},
})
await wrapper.find('[role="button"]').trigger('click')
await flushPromises()
expect(copyToClipboard).toHaveBeenCalledWith('https://default.example.com/v1', '已复制')
expect(wrapper.text()).toContain('已复制到剪贴板')
expect(wrapper.find('button[aria-label="已复制到剪贴板"]').exists()).toBe(true)
})
})

View File

@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
export type AddMethod = 'oauth' | 'setup-token'
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' | 'access_token'
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'mobile_refresh_token' | 'session_token' | 'access_token'
export interface OAuthState {
authUrl: string

View File

@ -13,6 +13,8 @@ export interface OpenAITokenInfo {
scope?: string
email?: string
name?: string
plan_type?: string
privacy_mode?: string
// OpenAI specific IDs (extracted from ID Token)
chatgpt_account_id?: string
chatgpt_user_id?: string
@ -126,9 +128,11 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
}
// Validate refresh token and get full token info
// clientId: 指定 OAuth client_id用于第三方渠道获取的 RT如 app_LlGpXReQgckcGGUo2JrYvtJK
const validateRefreshToken = async (
refreshToken: string,
proxyId?: number | null
proxyId?: number | null,
clientId?: string
): Promise<OpenAITokenInfo | null> => {
if (!refreshToken.trim()) {
error.value = 'Missing refresh token'
@ -143,11 +147,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(
refreshToken.trim(),
proxyId,
`${endpointPrefix}/refresh-token`
`${endpointPrefix}/refresh-token`,
clientId
)
return tokenInfo as OpenAITokenInfo
} catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to validate refresh token'
error.value = err.response?.data?.detail || err.message || 'Failed to validate refresh token'
appStore.showError(error.value)
return null
} finally {
@ -182,22 +187,23 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
}
}
// Build credentials for OpenAI OAuth account
// Build credentials for OpenAI OAuth account (aligned with backend BuildAccountCredentials)
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
const creds: Record<string, unknown> = {
access_token: tokenInfo.access_token,
refresh_token: tokenInfo.refresh_token,
token_type: tokenInfo.token_type,
expires_in: tokenInfo.expires_in,
expires_at: tokenInfo.expires_at,
scope: tokenInfo.scope
expires_at: tokenInfo.expires_at
}
if (tokenInfo.client_id) {
creds.client_id = tokenInfo.client_id
// 仅在返回了新的 refresh_token 时才写入,防止用空值覆盖已有令牌
if (tokenInfo.refresh_token) {
creds.refresh_token = tokenInfo.refresh_token
}
if (tokenInfo.id_token) {
creds.id_token = tokenInfo.id_token
}
if (tokenInfo.email) {
creds.email = tokenInfo.email
}
// Include OpenAI specific IDs (required for forwarding)
if (tokenInfo.chatgpt_account_id) {
creds.chatgpt_account_id = tokenInfo.chatgpt_account_id
}
@ -207,6 +213,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
if (tokenInfo.organization_id) {
creds.organization_id = tokenInfo.organization_id
}
if (tokenInfo.plan_type) {
creds.plan_type = tokenInfo.plan_type
}
if (tokenInfo.client_id) {
creds.client_id = tokenInfo.client_id
}
return creds
}
@ -220,6 +232,9 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
if (tokenInfo.name) {
extra.name = tokenInfo.name
}
if (tokenInfo.privacy_mode) {
extra.privacy_mode = tokenInfo.privacy_mode
}
return Object.keys(extra).length > 0 ? extra : undefined
}

View File

@ -533,6 +533,14 @@ export default {
title: 'API Keys',
description: 'Manage your API keys and access tokens',
searchPlaceholder: 'Search name or key...',
endpoints: {
title: 'API Endpoints',
default: 'Default',
copied: 'Copied',
copiedHint: 'Copied to clipboard',
clickToCopy: 'Click to copy this endpoint',
speedTest: 'Speed Test',
},
allGroups: 'All Groups',
allStatus: 'All Status',
createKey: 'Create API Key',
@ -1971,6 +1979,8 @@ export default {
expiresAt: 'Expires At',
actions: 'Actions'
},
allPrivacyModes: 'All Privacy States',
privacyUnset: 'Unset',
privacyTrainingOff: 'Training data sharing disabled',
privacyCfBlocked: 'Blocked by Cloudflare, training may still be on',
privacyFailed: 'Failed to disable training',
@ -3486,7 +3496,12 @@ export default {
typeRequest: 'Request',
typeAuth: 'Auth',
typeRouting: 'Routing',
typeInternal: 'Internal'
typeInternal: 'Internal',
endpoint: 'Endpoint',
requestType: 'Type',
requestTypeSync: 'Sync',
requestTypeStream: 'Stream',
requestTypeWs: 'WS'
},
// Error Details Modal
errorDetails: {
@ -3572,6 +3587,16 @@ export default {
latency: 'Request Duration',
businessLimited: 'Business Limited',
requestPath: 'Request Path',
inboundEndpoint: 'Inbound Endpoint',
upstreamEndpoint: 'Upstream Endpoint',
requestedModel: 'Requested Model',
upstreamModel: 'Upstream Model',
requestType: 'Request Type',
requestTypeUnknown: 'Unknown',
requestTypeSync: 'Sync',
requestTypeStream: 'Stream',
requestTypeWs: 'WebSocket',
modelMapping: 'Model Mapping',
timings: 'Timings',
auth: 'Auth',
routing: 'Routing',
@ -4162,6 +4187,18 @@ export default {
apiBaseUrlPlaceholder: 'https://api.example.com',
apiBaseUrlHint:
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
customEndpoints: {
title: 'Custom Endpoints',
description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page',
itemLabel: 'Endpoint #{n}',
name: 'Name',
namePlaceholder: 'e.g., OpenAI Compatible',
endpointUrl: 'Endpoint URL',
endpointUrlPlaceholder: 'https://api2.example.com',
descriptionLabel: 'Description',
descriptionPlaceholder: 'e.g., Supports OpenAI format requests',
add: 'Add Endpoint',
},
contactInfo: 'Contact Info',
contactInfoPlaceholder: 'e.g., QQ: 123456789',
contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.',

View File

@ -533,6 +533,14 @@ export default {
title: 'API 密钥',
description: '管理您的 API 密钥和访问令牌',
searchPlaceholder: '搜索名称或Key...',
endpoints: {
title: 'API 端点',
default: '默认',
copied: '已复制',
copiedHint: '已复制到剪贴板',
clickToCopy: '点击可复制此端点',
speedTest: '测速',
},
allGroups: '全部分组',
allStatus: '全部状态',
createKey: '创建密钥',
@ -2009,6 +2017,8 @@ export default {
expiresAt: '过期时间',
actions: '操作'
},
allPrivacyModes: '全部Privacy状态',
privacyUnset: '未设置',
privacyTrainingOff: '已关闭训练数据共享',
privacyCfBlocked: '被 Cloudflare 拦截,训练可能仍开启',
privacyFailed: '关闭训练数据共享失败',
@ -3651,7 +3661,12 @@ export default {
typeRequest: '请求',
typeAuth: '认证',
typeRouting: '路由',
typeInternal: '内部'
typeInternal: '内部',
endpoint: '端点',
requestType: '类型',
requestTypeSync: '同步',
requestTypeStream: '流式',
requestTypeWs: 'WS'
},
// Error Details Modal
errorDetails: {
@ -3737,6 +3752,16 @@ export default {
latency: '请求时长',
businessLimited: '业务限制',
requestPath: '请求路径',
inboundEndpoint: '入站端点',
upstreamEndpoint: '上游端点',
requestedModel: '请求模型',
upstreamModel: '上游模型',
requestType: '请求类型',
requestTypeUnknown: '未知',
requestTypeSync: '同步',
requestTypeStream: '流式',
requestTypeWs: 'WebSocket',
modelMapping: '模型映射',
timings: '时序信息',
auth: '认证',
routing: '路由',
@ -4324,6 +4349,18 @@ export default {
apiBaseUrl: 'API 端点地址',
apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址',
apiBaseUrlPlaceholder: 'https://api.example.com',
customEndpoints: {
title: '自定义端点',
description: '添加额外的 API 端点地址用户可在「API Keys」页面快速复制',
itemLabel: '端点 #{n}',
name: '名称',
namePlaceholder: '如OpenAI Compatible',
endpointUrl: '端点地址',
endpointUrlPlaceholder: 'https://api2.example.com',
descriptionLabel: '介绍',
descriptionPlaceholder: '如:支持 OpenAI 格式请求',
add: '添加端点',
},
contactInfo: '客服联系方式',
contactInfoPlaceholder: '例如QQ: 123456789',
contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置',

View File

@ -330,6 +330,7 @@ export const useAppStore = defineStore('app', () => {
purchase_subscription_enabled: false,
purchase_subscription_url: '',
custom_menu_items: [],
custom_endpoints: [],
linuxdo_oauth_enabled: false,
sora_client_enabled: false,
backend_mode_enabled: false,

View File

@ -84,6 +84,12 @@ export interface CustomMenuItem {
sort_order: number
}
export interface CustomEndpoint {
name: string
endpoint: string
description: string
}
export interface PublicSettings {
registration_enabled: boolean
email_verify_enabled: boolean
@ -104,6 +110,7 @@ export interface PublicSettings {
purchase_subscription_enabled: boolean
purchase_subscription_url: string
custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean
sora_client_enabled: boolean
backend_mode_enabled: boolean

View File

@ -581,7 +581,7 @@ const {
handlePageSizeChange: baseHandlePageSizeChange
} = useTableLoader<Account, any>({
fetchFn: adminAPI.accounts.list,
initialParams: { platform: '', type: '', status: '', group: '', search: '' }
initialParams: { platform: '', type: '', status: '', privacy_mode: '', group: '', search: '' }
})
const {
@ -758,6 +758,7 @@ const refreshAccountsIncrementally = async () => {
platform?: string
type?: string
status?: string
privacy_mode?: string
group?: string
search?: string

View File

@ -1248,6 +1248,81 @@
</p>
</div>
<!-- Custom Endpoints -->
<div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.customEndpoints.title') }}
</label>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.description') }}
</p>
<div class="space-y-3">
<div
v-for="(ep, index) in form.custom_endpoints"
:key="index"
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
>
<div class="mb-3 flex items-center justify-between">
<span class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.customEndpoints.itemLabel', { n: index + 1 }) }}
</span>
<button
type="button"
class="rounded p-1 text-red-400 hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
@click="removeEndpoint(index)"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" /></svg>
</button>
</div>
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.name') }}
</label>
<input
v-model="ep.name"
type="text"
class="input text-sm"
:placeholder="t('admin.settings.site.customEndpoints.namePlaceholder')"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.endpointUrl') }}
</label>
<input
v-model="ep.endpoint"
type="url"
class="input font-mono text-sm"
:placeholder="t('admin.settings.site.customEndpoints.endpointUrlPlaceholder')"
/>
</div>
<div class="sm:col-span-2">
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.descriptionLabel') }}
</label>
<input
v-model="ep.description"
type="text"
class="input text-sm"
:placeholder="t('admin.settings.site.customEndpoints.descriptionPlaceholder')"
/>
</div>
</div>
</div>
</div>
<button
type="button"
class="mt-3 flex w-full items-center justify-center gap-2 rounded-lg border-2 border-dashed border-gray-300 px-4 py-2.5 text-sm text-gray-500 transition-colors hover:border-primary-400 hover:text-primary-600 dark:border-dark-600 dark:text-gray-400 dark:hover:border-primary-500 dark:hover:text-primary-400"
@click="addEndpoint"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M12 4v16m8-8H4" /></svg>
{{ t('admin.settings.site.customEndpoints.add') }}
</button>
</div>
<!-- Contact Info -->
<div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
@ -1580,7 +1655,7 @@
<button
type="button"
@click="testSmtpConnection"
:disabled="testingSmtp"
:disabled="testingSmtp || loadFailed"
class="btn btn-secondary btn-sm"
>
<svg v-if="testingSmtp" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
@ -1650,6 +1725,11 @@
v-model="form.smtp_password"
type="password"
class="input"
autocomplete="new-password"
autocapitalize="off"
spellcheck="false"
@keydown="smtpPasswordManuallyEdited = true"
@paste="smtpPasswordManuallyEdited = true"
:placeholder="
form.smtp_password_configured
? t('admin.settings.smtp.passwordConfiguredPlaceholder')
@ -1732,7 +1812,7 @@
<button
type="button"
@click="sendTestEmail"
:disabled="sendingTestEmail || !testEmailAddress"
:disabled="sendingTestEmail || !testEmailAddress || loadFailed"
class="btn btn-secondary"
>
<svg
@ -1778,7 +1858,7 @@
<!-- Save Button -->
<div v-show="activeTab !== 'backup' && activeTab !== 'data'" class="flex justify-end">
<button type="submit" :disabled="saving" class="btn btn-primary">
<button type="submit" :disabled="saving || loadFailed" class="btn btn-primary">
<svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle
class="opacity-25"
@ -1849,9 +1929,11 @@ const settingsTabs = [
const { copyToClipboard } = useClipboard()
const loading = ref(true)
const loadFailed = ref(false)
const saving = ref(false)
const testingSmtp = ref(false)
const sendingTestEmail = ref(false)
const smtpPasswordManuallyEdited = ref(false)
const testEmailAddress = ref('')
const registrationEmailSuffixWhitelistTags = ref<string[]>([])
const registrationEmailSuffixWhitelistDraft = ref('')
@ -1945,6 +2027,7 @@ const form = reactive<SettingsForm>({
purchase_subscription_url: '',
sora_client_enabled: false,
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
custom_endpoints: [] as Array<{name: string; endpoint: string; description: string}>,
frontend_url: '',
smtp_host: '',
smtp_port: 587,
@ -2114,8 +2197,18 @@ function moveMenuItem(index: number, direction: -1 | 1) {
})
}
// Custom endpoint management
function addEndpoint() {
form.custom_endpoints.push({ name: '', endpoint: '', description: '' })
}
function removeEndpoint(index: number) {
form.custom_endpoints.splice(index, 1)
}
async function loadSettings() {
loading.value = true
loadFailed.value = false
try {
const settings = await adminAPI.settings.getSettings()
Object.assign(form, settings)
@ -2133,9 +2226,11 @@ async function loadSettings() {
)
registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
} catch (error: any) {
loadFailed.value = true
appStore.showError(
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
)
@ -2253,6 +2348,7 @@ async function saveSettings() {
purchase_subscription_url: form.purchase_subscription_url,
sora_client_enabled: form.sora_client_enabled,
custom_menu_items: form.custom_menu_items,
custom_endpoints: form.custom_endpoints,
frontend_url: form.frontend_url,
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
@ -2286,6 +2382,7 @@ async function saveSettings() {
)
registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
// Refresh cached settings so sidebar/header update immediately
@ -2304,11 +2401,12 @@ async function saveSettings() {
async function testSmtpConnection() {
testingSmtp.value = true
try {
const smtpPasswordForTest = smtpPasswordManuallyEdited.value ? form.smtp_password : ''
const result = await adminAPI.settings.testSmtpConnection({
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
smtp_username: form.smtp_username,
smtp_password: form.smtp_password,
smtp_password: smtpPasswordForTest,
smtp_use_tls: form.smtp_use_tls
})
// API returns { message: "..." } on success, errors are thrown as exceptions
@ -2330,12 +2428,13 @@ async function sendTestEmail() {
sendingTestEmail.value = true
try {
const smtpPasswordForSend = smtpPasswordManuallyEdited.value ? form.smtp_password : ''
const result = await adminAPI.settings.sendTestEmail({
email: testEmailAddress.value,
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
smtp_username: form.smtp_username,
smtp_password: form.smtp_password,
smtp_password: smtpPasswordForSend,
smtp_from_email: form.smtp_from_email,
smtp_from_name: form.smtp_from_name,
smtp_use_tls: form.smtp_use_tls

View File

@ -59,7 +59,28 @@
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.model') }}</div>
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">
{{ detail.model || '—' }}
<template v-if="hasModelMapping(detail)">
<span class="font-mono">{{ detail.requested_model }}</span>
<span class="mx-1 text-gray-400"></span>
<span class="font-mono text-primary-600 dark:text-primary-400">{{ detail.upstream_model }}</span>
</template>
<template v-else>
{{ displayModel(detail) || '—' }}
</template>
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.inboundEndpoint') }}</div>
<div class="mt-1 break-all font-mono text-sm font-medium text-gray-900 dark:text-white">
{{ detail.inbound_endpoint || '—' }}
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.upstreamEndpoint') }}</div>
<div class="mt-1 break-all font-mono text-sm font-medium text-gray-900 dark:text-white">
{{ detail.upstream_endpoint || '—' }}
</div>
</div>
@ -72,6 +93,13 @@
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.requestType') }}</div>
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">
{{ formatRequestTypeLabel(detail.request_type) }}
</div>
</div>
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.message') }}</div>
<div class="mt-1 truncate text-sm font-medium text-gray-900 dark:text-white" :title="detail.message">
@ -213,6 +241,31 @@ function isUpstreamError(d: OpsErrorDetail | null): boolean {
return phase === 'upstream' && owner === 'provider'
}
function formatRequestTypeLabel(type: number | null | undefined): string {
switch (type) {
case 1: return t('admin.ops.errorDetail.requestTypeSync')
case 2: return t('admin.ops.errorDetail.requestTypeStream')
case 3: return t('admin.ops.errorDetail.requestTypeWs')
default: return t('admin.ops.errorDetail.requestTypeUnknown')
}
}
function hasModelMapping(d: OpsErrorDetail | null): boolean {
if (!d) return false
const requested = String(d.requested_model || '').trim()
const upstream = String(d.upstream_model || '').trim()
return !!requested && !!upstream && requested !== upstream
}
function displayModel(d: OpsErrorDetail | null): string {
if (!d) return ''
const upstream = String(d.upstream_model || '').trim()
if (upstream) return upstream
const requested = String(d.requested_model || '').trim()
if (requested) return requested
return String(d.model || '').trim()
}
const correlatedUpstream = ref<OpsErrorDetail[]>([])
const correlatedUpstreamLoading = ref(false)

View File

@ -17,6 +17,9 @@
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('admin.ops.errorLog.type') }}
</th>
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('admin.ops.errorLog.endpoint') }}
</th>
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
{{ t('admin.ops.errorLog.platform') }}
</th>
@ -42,7 +45,7 @@
</thead>
<tbody class="divide-y divide-gray-100 dark:divide-dark-700">
<tr v-if="rows.length === 0">
<td colspan="9" class="py-12 text-center text-sm text-gray-400 dark:text-dark-500">
<td colspan="10" class="py-12 text-center text-sm text-gray-400 dark:text-dark-500">
{{ t('admin.ops.errorLog.noErrors') }}
</td>
</tr>
@ -74,6 +77,18 @@
</span>
</td>
<!-- Endpoint -->
<td class="px-4 py-2">
<div class="max-w-[160px]">
<el-tooltip v-if="log.inbound_endpoint" :content="formatEndpointTooltip(log)" placement="top" :show-after="500">
<span class="truncate font-mono text-[11px] text-gray-700 dark:text-gray-300">
{{ log.inbound_endpoint }}
</span>
</el-tooltip>
<span v-else class="text-xs text-gray-400">-</span>
</div>
</td>
<!-- Platform -->
<td class="whitespace-nowrap px-4 py-2">
<span class="inline-flex items-center rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold uppercase text-gray-600 dark:bg-dark-700 dark:text-gray-300">
@ -83,11 +98,22 @@
<!-- Model -->
<td class="px-4 py-2">
<div class="max-w-[120px] truncate" :title="log.model">
<span v-if="log.model" class="font-mono text-[11px] text-gray-700 dark:text-gray-300">
{{ log.model }}
</span>
<span v-else class="text-xs text-gray-400">-</span>
<div class="max-w-[160px]">
<template v-if="hasModelMapping(log)">
<el-tooltip :content="modelMappingTooltip(log)" placement="top" :show-after="500">
<span class="flex items-center gap-1 truncate font-mono text-[11px] text-gray-700 dark:text-gray-300">
<span class="truncate">{{ log.requested_model }}</span>
<span class="flex-shrink-0 text-gray-400"></span>
<span class="truncate text-primary-600 dark:text-primary-400">{{ log.upstream_model }}</span>
</span>
</el-tooltip>
</template>
<template v-else>
<span v-if="displayModel(log)" class="truncate font-mono text-[11px] text-gray-700 dark:text-gray-300" :title="displayModel(log)">
{{ displayModel(log) }}
</span>
<span v-else class="text-xs text-gray-400">-</span>
</template>
</div>
</td>
@ -138,6 +164,12 @@
>
{{ log.severity }}
</span>
<span
v-if="log.request_type != null && log.request_type > 0"
class="rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold text-gray-600 dark:bg-dark-700 dark:text-gray-300"
>
{{ formatRequestType(log.request_type) }}
</span>
</div>
</td>
@ -193,6 +225,44 @@ function isUpstreamRow(log: OpsErrorLog): boolean {
return phase === 'upstream' && owner === 'provider'
}
function formatEndpointTooltip(log: OpsErrorLog): string {
const parts: string[] = []
if (log.inbound_endpoint) parts.push(`Inbound: ${log.inbound_endpoint}`)
if (log.upstream_endpoint) parts.push(`Upstream: ${log.upstream_endpoint}`)
return parts.join('\n') || ''
}
function hasModelMapping(log: OpsErrorLog): boolean {
const requested = String(log.requested_model || '').trim()
const upstream = String(log.upstream_model || '').trim()
return !!requested && !!upstream && requested !== upstream
}
function modelMappingTooltip(log: OpsErrorLog): string {
const requested = String(log.requested_model || '').trim()
const upstream = String(log.upstream_model || '').trim()
if (!requested && !upstream) return ''
if (requested && upstream) return `${requested}${upstream}`
return upstream || requested
}
function displayModel(log: OpsErrorLog): string {
const upstream = String(log.upstream_model || '').trim()
if (upstream) return upstream
const requested = String(log.requested_model || '').trim()
if (requested) return requested
return String(log.model || '').trim()
}
function formatRequestType(type: number | null | undefined): string {
switch (type) {
case 1: return t('admin.ops.errorLog.requestTypeSync')
case 2: return t('admin.ops.errorLog.requestTypeStream')
case 3: return t('admin.ops.errorLog.requestTypeWs')
default: return ''
}
}
function getTypeBadge(log: OpsErrorLog): { label: string; className: string } {
const phase = String(log.phase || '').toLowerCase()
const owner = String(log.error_owner || '').toLowerCase()
@ -263,4 +333,4 @@ function formatSmartMessage(msg: string): string {
return msg.length > 200 ? msg.substring(0, 200) + '...' : msg
}
</script>
</script>

View File

@ -344,7 +344,7 @@ onMounted(async () => {
<div class="text-xs font-semibold text-gray-700 dark:text-gray-200">运行时日志配置实时生效</div>
<span v-if="runtimeLoading" class="text-xs text-gray-500">加载中...</span>
</div>
<div class="grid grid-cols-1 gap-3 md:grid-cols-6">
<div class="grid grid-cols-1 gap-3 md:grid-cols-2 xl:grid-cols-6">
<label class="text-xs text-gray-600 dark:text-gray-300">
级别
<select v-model="runtimeConfig.level" class="input mt-1">
@ -374,21 +374,27 @@ onMounted(async () => {
保留天数
<input v-model.number="runtimeConfig.retention_days" type="number" min="1" max="3650" class="input mt-1" />
</label>
<div class="flex items-end gap-2">
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
<input v-model="runtimeConfig.caller" type="checkbox" />
caller
</label>
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
<input v-model="runtimeConfig.enable_sampling" type="checkbox" />
sampling
</label>
<button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig">
{{ runtimeSaving ? '保存中...' : '保存并生效' }}
</button>
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig">
回滚默认值
</button>
<div class="md:col-span-2 xl:col-span-6">
<div class="grid gap-3 lg:grid-cols-[minmax(0,1fr)_auto] lg:items-end">
<div class="flex flex-wrap items-center gap-x-4 gap-y-2">
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
<input v-model="runtimeConfig.caller" type="checkbox" />
caller
</label>
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
<input v-model="runtimeConfig.enable_sampling" type="checkbox" />
sampling
</label>
</div>
<div class="flex flex-wrap items-center gap-2 lg:justify-end">
<button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig">
{{ runtimeSaving ? '保存中...' : '保存并生效' }}
</button>
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig">
回滚默认值
</button>
</div>
</div>
</div>
</div>
<p v-if="health.last_error" class="mt-2 text-xs text-red-600 dark:text-red-400">最近写入错误{{ health.last_error }}</p>

View File

@ -2,24 +2,31 @@
<AppLayout>
<TablePageLayout>
<template #filters>
<div class="flex flex-wrap items-center gap-3">
<SearchInput
v-model="filterSearch"
:placeholder="t('keys.searchPlaceholder')"
class="w-full sm:w-64"
@search="onFilterChange"
/>
<Select
:model-value="filterGroupId"
class="w-40"
:options="groupFilterOptions"
@update:model-value="onGroupFilterChange"
/>
<Select
:model-value="filterStatus"
class="w-40"
:options="statusFilterOptions"
@update:model-value="onStatusFilterChange"
<div class="flex flex-col gap-3">
<div class="flex flex-wrap items-center gap-3">
<SearchInput
v-model="filterSearch"
:placeholder="t('keys.searchPlaceholder')"
class="w-full sm:w-64"
@search="onFilterChange"
/>
<Select
:model-value="filterGroupId"
class="w-40"
:options="groupFilterOptions"
@update:model-value="onGroupFilterChange"
/>
<Select
:model-value="filterStatus"
class="w-40"
:options="statusFilterOptions"
@update:model-value="onStatusFilterChange"
/>
</div>
<EndpointPopover
v-if="publicSettings?.api_base_url || (publicSettings?.custom_endpoints?.length ?? 0) > 0"
:api-base-url="publicSettings?.api_base_url || ''"
:custom-endpoints="publicSettings?.custom_endpoints || []"
/>
</div>
</template>
@ -1050,6 +1057,7 @@ import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import SearchInput from '@/components/common/SearchInput.vue'
import Icon from '@/components/icons/Icon.vue'
import UseKeyModal from '@/components/keys/UseKeyModal.vue'
import EndpointPopover from '@/components/keys/EndpointPopover.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types'