Merge branch 'Wei-Shaw:main' into main
This commit is contained in:
commit
fa68cbad1b
@ -61,6 +61,9 @@ temp/
|
||||
deploy/install.sh
|
||||
deploy/sub2api.service
|
||||
deploy/sub2api-sudoers
|
||||
deploy/data/
|
||||
deploy/postgres_data/
|
||||
deploy/redis_data/
|
||||
|
||||
# GoReleaser
|
||||
.goreleaser.yaml
|
||||
|
||||
@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
|
||||
@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
||||
pageSize := dataPageCap
|
||||
var out []service.Account
|
||||
for {
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
|
||||
@ -233,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
|
||||
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
|
||||
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
|
||||
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
|
||||
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
|
||||
|
||||
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
|
||||
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
|
||||
if req.SMTPHost == "" && previousSettings.SMTPHost != "" {
|
||||
req.SMTPHost = previousSettings.SMTPHost
|
||||
req.SMTPPort = previousSettings.SMTPPort
|
||||
req.SMTPUsername = previousSettings.SMTPUsername
|
||||
req.SMTPFrom = previousSettings.SMTPFrom
|
||||
req.SMTPFromName = previousSettings.SMTPFromName
|
||||
req.SMTPUseTLS = previousSettings.SMTPUseTLS
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
// 检查必填字段
|
||||
@ -881,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
@ -897,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
|
||||
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
|
||||
|
||||
var savedConfig *service.SMTPConfig
|
||||
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
|
||||
savedConfig = cfg
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
if req.SMTPHost == "" && savedConfig != nil {
|
||||
req.SMTPHost = savedConfig.Host
|
||||
}
|
||||
if req.SMTPPort <= 0 {
|
||||
if savedConfig != nil && savedConfig.Port > 0 {
|
||||
req.SMTPPort = savedConfig.Port
|
||||
} else {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
}
|
||||
if req.SMTPUsername == "" && savedConfig != nil {
|
||||
req.SMTPUsername = savedConfig.Username
|
||||
}
|
||||
password := strings.TrimSpace(req.SMTPPassword)
|
||||
if password == "" && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
if req.SMTPHost == "" {
|
||||
response.BadRequest(c, "SMTP host is required")
|
||||
return
|
||||
}
|
||||
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
@ -930,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
@ -948,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
|
||||
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
|
||||
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
|
||||
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
|
||||
|
||||
var savedConfig *service.SMTPConfig
|
||||
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
|
||||
savedConfig = cfg
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
if req.SMTPHost == "" && savedConfig != nil {
|
||||
req.SMTPHost = savedConfig.Host
|
||||
}
|
||||
if req.SMTPPort <= 0 {
|
||||
if savedConfig != nil && savedConfig.Port > 0 {
|
||||
req.SMTPPort = savedConfig.Port
|
||||
} else {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
}
|
||||
if req.SMTPUsername == "" && savedConfig != nil {
|
||||
req.SMTPUsername = savedConfig.Username
|
||||
}
|
||||
password := strings.TrimSpace(req.SMTPPassword)
|
||||
if password == "" && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
if req.SMTPFrom == "" && savedConfig != nil {
|
||||
req.SMTPFrom = savedConfig.From
|
||||
}
|
||||
if req.SMTPFromName == "" && savedConfig != nil {
|
||||
req.SMTPFromName = savedConfig.FromName
|
||||
}
|
||||
if req.SMTPHost == "" {
|
||||
response.BadRequest(c, "SMTP host is required")
|
||||
return
|
||||
}
|
||||
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
|
||||
@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
289
backend/internal/handler/gateway_handler_chat_completions.go
Normal file
289
backend/internal/handler/gateway_handler_chat_completions.go
Normal file
@ -0,0 +1,289 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
|
||||
// POST /v1/chat/completions
|
||||
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
|
||||
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
|
||||
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// Claude Code only restriction
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group is restricted to Claude Code clients (/v1/messages only)")
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
|
||||
// 1. Acquire user concurrency slot
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.chatCompletionsErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request for session hash
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
|
||||
if parsedReq == nil {
|
||||
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
|
||||
}
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default:
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 4. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleCCFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
}
|
||||
h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.cc.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Record usage
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.cc.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
|
||||
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
|
||||
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
if streamStarted {
|
||||
return
|
||||
}
|
||||
statusCode := http.StatusBadGateway
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
295
backend/internal/handler/gateway_handler_responses.go
Normal file
295
backend/internal/handler/gateway_handler_responses.go
Normal file
@ -0,0 +1,295 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
|
||||
// POST /v1/responses
|
||||
// This converts Responses API requests to Anthropic format, forwards to Anthropic
|
||||
// upstream, and converts responses back to Responses format.
|
||||
func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.responses",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream using gjson (like OpenAI handler)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// Claude Code only restriction:
|
||||
// /v1/responses is never a Claude Code endpoint.
|
||||
// When claude_code_only is enabled, this endpoint is rejected.
|
||||
// The existing service-layer checkClaudeCodeRestriction handles degradation
|
||||
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
|
||||
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||
"This group is restricted to Claude Code clients (/v1/messages only)")
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
|
||||
// 1. Acquire user concurrency slot
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.responsesErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request for session hash
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
|
||||
if parsedReq == nil {
|
||||
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
|
||||
}
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default:
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 4. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// Can't failover if streaming content already sent
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleResponsesFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
}
|
||||
h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.responses.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Record usage
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("gateway.responses.record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// responsesErrorResponse writes an error in OpenAI Responses API format.
|
||||
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
|
||||
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
if streamStarted {
|
||||
return // Can't write error after stream started
|
||||
}
|
||||
statusCode := http.StatusBadGateway
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
|
||||
@ -27,6 +27,9 @@ const (
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
opsUpstreamModelKey = "ops_upstream_model"
|
||||
opsRequestTypeKey = "ops_request_type"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
}
|
||||
}
|
||||
|
||||
// setOpsEndpointContext stores upstream model and request type for ops error logging.
|
||||
// Called by handlers after model mapping and request type determination.
|
||||
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" {
|
||||
c.Set(opsUpstreamModelKey, upstreamModel)
|
||||
}
|
||||
c.Set(opsRequestTypeKey, requestType)
|
||||
}
|
||||
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: stream,
|
||||
Stream: stream,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
|
||||
RequestedModel: modelName,
|
||||
UpstreamModel: func() string {
|
||||
if v, ok := c.Get(opsUpstreamModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestType: func() *int16 {
|
||||
if v, ok := c.Get(opsRequestTypeKey); ok {
|
||||
switch t := v.(type) {
|
||||
case int16:
|
||||
return &t
|
||||
case int:
|
||||
v16 := int16(t)
|
||||
return &v16
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: "upstream",
|
||||
@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: stream,
|
||||
Stream: stream,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
|
||||
RequestedModel: modelName,
|
||||
UpstreamModel: func() string {
|
||||
if v, ok := c.Get(opsUpstreamModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
RequestType: func() *int16 {
|
||||
if v, ok := c.Get(opsRequestTypeKey); ok {
|
||||
switch t := v.(type) {
|
||||
case int16:
|
||||
return &t
|
||||
case int:
|
||||
v16 := int16(t)
|
||||
return &v16
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: phase,
|
||||
|
||||
@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream
|
||||
|
||||
v, ok := c.Get(opsUpstreamModelKey)
|
||||
require.True(t, ok)
|
||||
vStr, ok := v.(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "claude-3-5-sonnet-20241022", vStr)
|
||||
|
||||
rt, ok := c.Get(opsRequestTypeKey)
|
||||
require.True(t, ok)
|
||||
rtVal, ok := rt.(int16)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int16(2), rtVal)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
setOpsEndpointContext(c, "", int16(1))
|
||||
|
||||
_, ok := c.Get(opsUpstreamModelKey)
|
||||
require.False(t, ok, "empty upstream model should not be stored")
|
||||
|
||||
rt, ok := c.Get(opsRequestTypeKey)
|
||||
require.True(t, ok)
|
||||
rtVal, ok := rt.(int16)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int16(1), rtVal)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_NilContext(t *testing.T) {
|
||||
require.NotPanics(t, func() {
|
||||
setOpsEndpointContext(nil, "model", int16(1))
|
||||
})
|
||||
}
|
||||
|
||||
@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error
|
||||
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
|
||||
|
||||
@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
|
||||
|
||||
platform := ""
|
||||
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
|
||||
@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
|
||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
|
||||
@ -0,0 +1,521 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: AnthropicResponse → ResponsesResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicToResponsesResponse converts an Anthropic Messages response into a
|
||||
// Responses API response. This is the reverse of ResponsesToAnthropic and
|
||||
// enables Anthropic upstream responses to be returned in OpenAI Responses format.
|
||||
func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
|
||||
id := resp.ID
|
||||
if id == "" {
|
||||
id = generateResponsesID()
|
||||
}
|
||||
|
||||
out := &ResponsesResponse{
|
||||
ID: id,
|
||||
Object: "response",
|
||||
Model: resp.Model,
|
||||
}
|
||||
|
||||
var outputs []ResponsesOutput
|
||||
var msgParts []ResponsesContentPart
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "thinking":
|
||||
if block.Thinking != "" {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "reasoning",
|
||||
ID: generateItemID(),
|
||||
Summary: []ResponsesSummary{{
|
||||
Type: "summary_text",
|
||||
Text: block.Thinking,
|
||||
}},
|
||||
})
|
||||
}
|
||||
case "text":
|
||||
if block.Text != "" {
|
||||
msgParts = append(msgParts, ResponsesContentPart{
|
||||
Type: "output_text",
|
||||
Text: block.Text,
|
||||
})
|
||||
}
|
||||
case "tool_use":
|
||||
args := "{}"
|
||||
if len(block.Input) > 0 {
|
||||
args = string(block.Input)
|
||||
}
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "function_call",
|
||||
ID: generateItemID(),
|
||||
CallID: toResponsesCallID(block.ID),
|
||||
Name: block.Name,
|
||||
Arguments: args,
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Assemble message output item from text parts
|
||||
if len(msgParts) > 0 {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: generateItemID(),
|
||||
Role: "assistant",
|
||||
Content: msgParts,
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
|
||||
if len(outputs) == 0 {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: generateItemID(),
|
||||
Role: "assistant",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
out.Output = outputs
|
||||
|
||||
// Map stop_reason → status
|
||||
out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content)
|
||||
if out.Status == "incomplete" {
|
||||
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
|
||||
}
|
||||
|
||||
// Usage
|
||||
out.Usage = &ResponsesUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.CacheReadInputTokens > 0 {
|
||||
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
|
||||
CachedTokens: resp.Usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status.
|
||||
func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string {
|
||||
switch stopReason {
|
||||
case "max_tokens":
|
||||
return "incomplete"
|
||||
case "end_turn", "tool_use", "stop_sequence":
|
||||
return "completed"
|
||||
default:
|
||||
return "completed"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// AnthropicEventToResponsesState tracks state for converting a sequence of
|
||||
// Anthropic SSE events into Responses SSE events.
|
||||
type AnthropicEventToResponsesState struct {
|
||||
ResponseID string
|
||||
Model string
|
||||
Created int64
|
||||
SequenceNumber int
|
||||
|
||||
// CreatedSent tracks whether response.created has been emitted.
|
||||
CreatedSent bool
|
||||
// CompletedSent tracks whether the terminal event has been emitted.
|
||||
CompletedSent bool
|
||||
|
||||
// Current output tracking
|
||||
OutputIndex int
|
||||
CurrentItemID string
|
||||
CurrentItemType string // "message" | "function_call" | "reasoning"
|
||||
|
||||
// For message output: accumulate text parts
|
||||
ContentIndex int
|
||||
|
||||
// For function_call: track per-output info
|
||||
CurrentCallID string
|
||||
CurrentName string
|
||||
|
||||
// Usage from message_delta
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheReadInputTokens int
|
||||
}
|
||||
|
||||
// NewAnthropicEventToResponsesState returns an initialised stream state.
|
||||
func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState {
|
||||
return &AnthropicEventToResponsesState{
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into
|
||||
// zero or more Responses SSE events, updating state as it goes.
|
||||
func AnthropicEventToResponsesEvents(
|
||||
evt *AnthropicStreamEvent,
|
||||
state *AnthropicEventToResponsesState,
|
||||
) []ResponsesStreamEvent {
|
||||
switch evt.Type {
|
||||
case "message_start":
|
||||
return anthToResHandleMessageStart(evt, state)
|
||||
case "content_block_start":
|
||||
return anthToResHandleContentBlockStart(evt, state)
|
||||
case "content_block_delta":
|
||||
return anthToResHandleContentBlockDelta(evt, state)
|
||||
case "content_block_stop":
|
||||
return anthToResHandleContentBlockStop(evt, state)
|
||||
case "message_delta":
|
||||
return anthToResHandleMessageDelta(evt, state)
|
||||
case "message_stop":
|
||||
return anthToResHandleMessageStop(state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeAnthropicResponsesStream emits synthetic termination events if the
|
||||
// stream ended without a proper message_stop.
|
||||
func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
if !state.CreatedSent || state.CompletedSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []ResponsesStreamEvent
|
||||
|
||||
// Close any open item
|
||||
events = append(events, closeCurrentResponsesItem(state)...)
|
||||
|
||||
// Emit response.completed
|
||||
events = append(events, makeResponsesCompletedEvent(state, "completed", nil))
|
||||
state.CompletedSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line.
|
||||
func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) {
|
||||
data, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
if evt.Message != nil {
|
||||
state.ResponseID = evt.Message.ID
|
||||
if state.Model == "" {
|
||||
state.Model = evt.Message.Model
|
||||
}
|
||||
if evt.Message.Usage.InputTokens > 0 {
|
||||
state.InputTokens = evt.Message.Usage.InputTokens
|
||||
}
|
||||
}
|
||||
|
||||
if state.CreatedSent {
|
||||
return nil
|
||||
}
|
||||
state.CreatedSent = true
|
||||
|
||||
// Emit response.created
|
||||
return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)}
|
||||
}
|
||||
|
||||
func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
if evt.ContentBlock == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []ResponsesStreamEvent
|
||||
|
||||
switch evt.ContentBlock.Type {
|
||||
case "thinking":
|
||||
state.CurrentItemID = generateItemID()
|
||||
state.CurrentItemType = "reasoning"
|
||||
state.ContentIndex = 0
|
||||
|
||||
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "reasoning",
|
||||
ID: state.CurrentItemID,
|
||||
},
|
||||
}))
|
||||
|
||||
case "text":
|
||||
// If we don't have an open message item, open one
|
||||
if state.CurrentItemType != "message" {
|
||||
state.CurrentItemID = generateItemID()
|
||||
state.CurrentItemType = "message"
|
||||
state.ContentIndex = 0
|
||||
|
||||
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: state.CurrentItemID,
|
||||
Role: "assistant",
|
||||
Status: "in_progress",
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
case "tool_use":
|
||||
// Close previous item if any
|
||||
events = append(events, closeCurrentResponsesItem(state)...)
|
||||
|
||||
state.CurrentItemID = generateItemID()
|
||||
state.CurrentItemType = "function_call"
|
||||
state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID)
|
||||
state.CurrentName = evt.ContentBlock.Name
|
||||
|
||||
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
ID: state.CurrentItemID,
|
||||
CallID: state.CurrentCallID,
|
||||
Name: state.CurrentName,
|
||||
Status: "in_progress",
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
if evt.Delta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch evt.Delta.Type {
|
||||
case "text_delta":
|
||||
if evt.Delta.Text == "" {
|
||||
return nil
|
||||
}
|
||||
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
ContentIndex: state.ContentIndex,
|
||||
Delta: evt.Delta.Text,
|
||||
ItemID: state.CurrentItemID,
|
||||
})}
|
||||
|
||||
case "thinking_delta":
|
||||
if evt.Delta.Thinking == "" {
|
||||
return nil
|
||||
}
|
||||
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
SummaryIndex: 0,
|
||||
Delta: evt.Delta.Thinking,
|
||||
ItemID: state.CurrentItemID,
|
||||
})}
|
||||
|
||||
case "input_json_delta":
|
||||
if evt.Delta.PartialJSON == "" {
|
||||
return nil
|
||||
}
|
||||
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
Delta: evt.Delta.PartialJSON,
|
||||
ItemID: state.CurrentItemID,
|
||||
CallID: state.CurrentCallID,
|
||||
Name: state.CurrentName,
|
||||
})}
|
||||
|
||||
case "signature_delta":
|
||||
// Anthropic signature deltas have no Responses equivalent; skip
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
switch state.CurrentItemType {
|
||||
case "reasoning":
|
||||
// Emit reasoning summary done + output item done
|
||||
events := []ResponsesStreamEvent{
|
||||
makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
SummaryIndex: 0,
|
||||
ItemID: state.CurrentItemID,
|
||||
}),
|
||||
}
|
||||
events = append(events, closeCurrentResponsesItem(state)...)
|
||||
return events
|
||||
|
||||
case "function_call":
|
||||
// Emit function_call_arguments.done + output item done
|
||||
events := []ResponsesStreamEvent{
|
||||
makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
ItemID: state.CurrentItemID,
|
||||
CallID: state.CurrentCallID,
|
||||
Name: state.CurrentName,
|
||||
}),
|
||||
}
|
||||
events = append(events, closeCurrentResponsesItem(state)...)
|
||||
return events
|
||||
|
||||
case "message":
|
||||
// Emit output_text.done (text block is done, but message item stays open for potential more blocks)
|
||||
return []ResponsesStreamEvent{
|
||||
makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex,
|
||||
ContentIndex: state.ContentIndex,
|
||||
ItemID: state.CurrentItemID,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
// Update usage
|
||||
if evt.Usage != nil {
|
||||
state.OutputTokens = evt.Usage.OutputTokens
|
||||
if evt.Usage.CacheReadInputTokens > 0 {
|
||||
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
if state.CompletedSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
var events []ResponsesStreamEvent
|
||||
|
||||
// Close any open item
|
||||
events = append(events, closeCurrentResponsesItem(state)...)
|
||||
|
||||
// Determine status
|
||||
status := "completed"
|
||||
var incompleteDetails *ResponsesIncompleteDetails
|
||||
|
||||
// Emit response.completed
|
||||
events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails))
|
||||
state.CompletedSent = true
|
||||
return events
|
||||
}
|
||||
|
||||
// --- helper functions ---
|
||||
|
||||
func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
|
||||
if state.CurrentItemType == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
itemType := state.CurrentItemType
|
||||
itemID := state.CurrentItemID
|
||||
|
||||
// Reset
|
||||
state.CurrentItemType = ""
|
||||
state.CurrentItemID = ""
|
||||
state.CurrentCallID = ""
|
||||
state.CurrentName = ""
|
||||
state.OutputIndex++
|
||||
state.ContentIndex = 0
|
||||
|
||||
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
|
||||
OutputIndex: state.OutputIndex - 1, // Use the index before increment
|
||||
Item: &ResponsesOutput{
|
||||
Type: itemType,
|
||||
ID: itemID,
|
||||
Status: "completed",
|
||||
},
|
||||
})}
|
||||
}
|
||||
|
||||
func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent {
|
||||
seq := state.SequenceNumber
|
||||
state.SequenceNumber++
|
||||
return ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
SequenceNumber: seq,
|
||||
Response: &ResponsesResponse{
|
||||
ID: state.ResponseID,
|
||||
Object: "response",
|
||||
Model: state.Model,
|
||||
Status: "in_progress",
|
||||
Output: []ResponsesOutput{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeResponsesCompletedEvent(
|
||||
state *AnthropicEventToResponsesState,
|
||||
status string,
|
||||
incompleteDetails *ResponsesIncompleteDetails,
|
||||
) ResponsesStreamEvent {
|
||||
seq := state.SequenceNumber
|
||||
state.SequenceNumber++
|
||||
|
||||
usage := &ResponsesUsage{
|
||||
InputTokens: state.InputTokens,
|
||||
OutputTokens: state.OutputTokens,
|
||||
TotalTokens: state.InputTokens + state.OutputTokens,
|
||||
}
|
||||
if state.CacheReadInputTokens > 0 {
|
||||
usage.InputTokensDetails = &ResponsesInputTokensDetails{
|
||||
CachedTokens: state.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
SequenceNumber: seq,
|
||||
Response: &ResponsesResponse{
|
||||
ID: state.ResponseID,
|
||||
Object: "response",
|
||||
Model: state.Model,
|
||||
Status: status,
|
||||
Output: []ResponsesOutput{}, // Simplified; full output tracking would add complexity
|
||||
Usage: usage,
|
||||
IncompleteDetails: incompleteDetails,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent {
|
||||
seq := state.SequenceNumber
|
||||
state.SequenceNumber++
|
||||
|
||||
evt := *template
|
||||
evt.Type = eventType
|
||||
evt.SequenceNumber = seq
|
||||
return evt
|
||||
}
|
||||
|
||||
func generateResponsesID() string {
|
||||
b := make([]byte, 12)
|
||||
_, _ = rand.Read(b)
|
||||
return "resp_" + hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func generateItemID() string {
|
||||
b := make([]byte, 12)
|
||||
_, _ = rand.Read(b)
|
||||
return "item_" + hex.EncodeToString(b)
|
||||
}
|
||||
464
backend/internal/pkg/apicompat/responses_to_anthropic_request.go
Normal file
464
backend/internal/pkg/apicompat/responses_to_anthropic_request.go
Normal file
@ -0,0 +1,464 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ResponsesToAnthropicRequest converts a Responses API request into an
|
||||
// Anthropic Messages request. This is the reverse of AnthropicToResponses and
|
||||
// enables Anthropic platform groups to accept OpenAI Responses API requests
|
||||
// by converting them to the native /v1/messages format before forwarding upstream.
|
||||
func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) {
|
||||
system, messages, err := convertResponsesInputToAnthropic(req.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &AnthropicRequest{
|
||||
Model: req.Model,
|
||||
Messages: messages,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
out.System = system
|
||||
}
|
||||
|
||||
// max_output_tokens → max_tokens
|
||||
if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 {
|
||||
out.MaxTokens = *req.MaxOutputTokens
|
||||
}
|
||||
if out.MaxTokens == 0 {
|
||||
// Anthropic requires max_tokens; default to a sensible value.
|
||||
out.MaxTokens = 8192
|
||||
}
|
||||
|
||||
// Convert tools
|
||||
if len(req.Tools) > 0 {
|
||||
out.Tools = convertResponsesToAnthropicTools(req.Tools)
|
||||
}
|
||||
|
||||
// Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses)
|
||||
if len(req.ToolChoice) > 0 {
|
||||
tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert tool_choice: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
// reasoning.effort → output_config.effort + thinking
|
||||
if req.Reasoning != nil && req.Reasoning.Effort != "" {
|
||||
effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort)
|
||||
out.OutputConfig = &AnthropicOutputConfig{Effort: effort}
|
||||
// Enable thinking for non-low efforts
|
||||
if effort != "low" {
|
||||
out.Thinking = &AnthropicThinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: defaultThinkingBudget(effort),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// defaultThinkingBudget returns a sensible thinking budget based on effort level.
|
||||
func defaultThinkingBudget(effort string) int {
|
||||
switch effort {
|
||||
case "low":
|
||||
return 1024
|
||||
case "medium":
|
||||
return 4096
|
||||
case "high":
|
||||
return 10240
|
||||
case "max":
|
||||
return 32768
|
||||
default:
|
||||
return 10240
|
||||
}
|
||||
}
|
||||
|
||||
// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to
|
||||
// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses.
|
||||
//
|
||||
// low → low
|
||||
// medium → medium
|
||||
// high → high
|
||||
// xhigh → max
|
||||
func mapResponsesEffortToAnthropic(effort string) string {
|
||||
if effort == "xhigh" {
|
||||
return "max"
|
||||
}
|
||||
return effort // low→low, medium→medium, high→high, unknown→passthrough
|
||||
}
|
||||
|
||||
// convertResponsesInputToAnthropic extracts system prompt and messages from
|
||||
// a Responses API input array. Returns the system as raw JSON (for Anthropic's
|
||||
// polymorphic system field) and a list of Anthropic messages.
|
||||
func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) {
|
||||
// Try as plain string input.
|
||||
var inputStr string
|
||||
if err := json.Unmarshal(inputRaw, &inputStr); err == nil {
|
||||
content, _ := json.Marshal(inputStr)
|
||||
return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var items []ResponsesInputItem
|
||||
if err := json.Unmarshal(inputRaw, &items); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse responses input: %w", err)
|
||||
}
|
||||
|
||||
var system json.RawMessage
|
||||
var messages []AnthropicMessage
|
||||
|
||||
for _, item := range items {
|
||||
switch {
|
||||
case item.Role == "system":
|
||||
// System prompt → Anthropic system field
|
||||
text := extractTextFromContent(item.Content)
|
||||
if text != "" {
|
||||
system, _ = json.Marshal(text)
|
||||
}
|
||||
|
||||
case item.Type == "function_call":
|
||||
// function_call → assistant message with tool_use block
|
||||
input := json.RawMessage("{}")
|
||||
if item.Arguments != "" {
|
||||
input = json.RawMessage(item.Arguments)
|
||||
}
|
||||
block := AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: fromResponsesCallIDToAnthropic(item.CallID),
|
||||
Name: item.Name,
|
||||
Input: input,
|
||||
}
|
||||
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
|
||||
messages = append(messages, AnthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: blockJSON,
|
||||
})
|
||||
|
||||
case item.Type == "function_call_output":
|
||||
// function_call_output → user message with tool_result block
|
||||
outputContent := item.Output
|
||||
if outputContent == "" {
|
||||
outputContent = "(empty)"
|
||||
}
|
||||
contentJSON, _ := json.Marshal(outputContent)
|
||||
block := AnthropicContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: fromResponsesCallIDToAnthropic(item.CallID),
|
||||
Content: contentJSON,
|
||||
}
|
||||
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
|
||||
messages = append(messages, AnthropicMessage{
|
||||
Role: "user",
|
||||
Content: blockJSON,
|
||||
})
|
||||
|
||||
case item.Role == "user":
|
||||
content, err := convertResponsesUserToAnthropicContent(item.Content)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
messages = append(messages, AnthropicMessage{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
|
||||
case item.Role == "assistant":
|
||||
content, err := convertResponsesAssistantToAnthropicContent(item.Content)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
messages = append(messages, AnthropicMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
})
|
||||
|
||||
default:
|
||||
// Unknown role/type — attempt as user message
|
||||
if item.Content != nil {
|
||||
messages = append(messages, AnthropicMessage{
|
||||
Role: "user",
|
||||
Content: item.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge consecutive same-role messages (Anthropic requires alternating roles)
|
||||
messages = mergeConsecutiveMessages(messages)
|
||||
|
||||
return system, messages, nil
|
||||
}
|
||||
|
||||
// extractTextFromContent extracts text from a content field that may be a
|
||||
// plain string or an array of content parts.
|
||||
func extractTextFromContent(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
var parts []ResponsesContentPart
|
||||
if err := json.Unmarshal(raw, &parts); err == nil {
|
||||
var texts []string
|
||||
for _, p := range parts {
|
||||
if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" {
|
||||
texts = append(texts, p.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// convertResponsesUserToAnthropicContent converts a Responses user message
|
||||
// content field into Anthropic content blocks JSON.
|
||||
func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
|
||||
if len(raw) == 0 {
|
||||
return json.Marshal("") // empty string content
|
||||
}
|
||||
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Array of content parts → Anthropic content blocks.
|
||||
var parts []ResponsesContentPart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Pass through as-is if we can't parse
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "input_text", "text":
|
||||
if p.Text != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "input_image":
|
||||
src := dataURIToAnthropicImageSource(p.ImageURL)
|
||||
if src != nil {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "image",
|
||||
Source: src,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
return json.Marshal("")
|
||||
}
|
||||
return json.Marshal(blocks)
|
||||
}
|
||||
|
||||
// convertResponsesAssistantToAnthropicContent converts a Responses assistant
|
||||
// message content field into Anthropic content blocks JSON.
|
||||
func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
|
||||
if len(raw) == 0 {
|
||||
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}})
|
||||
}
|
||||
|
||||
// Try plain string.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}})
|
||||
}
|
||||
|
||||
// Array of content parts → Anthropic content blocks.
|
||||
var parts []ResponsesContentPart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
var blocks []AnthropicContentBlock
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "output_text", "text":
|
||||
if p.Text != "" {
|
||||
blocks = append(blocks, AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
|
||||
}
|
||||
return json.Marshal(blocks)
|
||||
}
|
||||
|
||||
// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to
|
||||
// Anthropic format. Reverses toResponsesCallID.
|
||||
func fromResponsesCallIDToAnthropic(id string) string {
|
||||
// If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it
|
||||
if after, ok := strings.CutPrefix(id, "fc_"); ok {
|
||||
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
|
||||
return after
|
||||
}
|
||||
}
|
||||
// Generate a synthetic Anthropic tool ID
|
||||
if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") {
|
||||
return "toolu_" + id
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource.
|
||||
func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource {
|
||||
if !strings.HasPrefix(dataURI, "data:") {
|
||||
return nil
|
||||
}
|
||||
// Format: data:<media_type>;base64,<data>
|
||||
rest := strings.TrimPrefix(dataURI, "data:")
|
||||
semicolonIdx := strings.Index(rest, ";")
|
||||
if semicolonIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
mediaType := rest[:semicolonIdx]
|
||||
rest = rest[semicolonIdx+1:]
|
||||
if !strings.HasPrefix(rest, "base64,") {
|
||||
return nil
|
||||
}
|
||||
data := strings.TrimPrefix(rest, "base64,")
|
||||
return &AnthropicImageSource{
|
||||
Type: "base64",
|
||||
MediaType: mediaType,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// mergeConsecutiveMessages merges consecutive messages with the same role
|
||||
// because Anthropic requires alternating user/assistant turns.
|
||||
func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage {
|
||||
if len(messages) <= 1 {
|
||||
return messages
|
||||
}
|
||||
|
||||
var merged []AnthropicMessage
|
||||
for _, msg := range messages {
|
||||
if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role {
|
||||
merged = append(merged, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Same role — merge content arrays
|
||||
last := &merged[len(merged)-1]
|
||||
lastBlocks := parseContentBlocks(last.Content)
|
||||
newBlocks := parseContentBlocks(msg.Content)
|
||||
combined := append(lastBlocks, newBlocks...)
|
||||
last.Content, _ = json.Marshal(combined)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// parseContentBlocks attempts to parse content as []AnthropicContentBlock.
|
||||
// If it's a string, wraps it in a text block.
|
||||
func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock {
|
||||
var blocks []AnthropicContentBlock
|
||||
if err := json.Unmarshal(raw, &blocks); err == nil {
|
||||
return blocks
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return []AnthropicContentBlock{{Type: "text", Text: s}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format.
|
||||
// Reverse of convertAnthropicToolsToResponses.
|
||||
func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
|
||||
var out []AnthropicTool
|
||||
for _, t := range tools {
|
||||
switch t.Type {
|
||||
case "web_search":
|
||||
out = append(out, AnthropicTool{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
})
|
||||
case "function":
|
||||
out = append(out, AnthropicTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: normalizeAnthropicInputSchema(t.Parameters),
|
||||
})
|
||||
default:
|
||||
// Pass through unknown tool types
|
||||
out = append(out, AnthropicTool{
|
||||
Type: t.Type,
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: t.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeAnthropicInputSchema ensures the input_schema has a "type" field.
|
||||
func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
|
||||
if len(schema) == 0 || string(schema) == "null" {
|
||||
return json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format.
|
||||
// Reverse of convertAnthropicToolChoiceToResponses.
|
||||
//
|
||||
// "auto" → {"type":"auto"}
|
||||
// "required" → {"type":"any"}
|
||||
// "none" → {"type":"none"}
|
||||
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
|
||||
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try as string first
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
switch s {
|
||||
case "auto":
|
||||
return json.Marshal(map[string]string{"type": "auto"})
|
||||
case "required":
|
||||
return json.Marshal(map[string]string{"type": "any"})
|
||||
case "none":
|
||||
return json.Marshal(map[string]string{"type": "none"})
|
||||
default:
|
||||
return raw, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try as object with type=function
|
||||
var tc struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
|
||||
return json.Marshal(map[string]string{
|
||||
"type": "tool",
|
||||
"name": tc.Function.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// Pass through unknown
|
||||
return raw, nil
|
||||
}
|
||||
@ -270,6 +270,7 @@ type OpenAIAuthClaims struct {
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||
UserID string `json:"user_id"`
|
||||
POID string `json:"poid"` // organization ID in access_token JWT
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
|
||||
|
||||
@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||
_, err := r.client.Account.UpdateOneID(id).
|
||||
SetCredentials(normalizeJSONMap(credentials)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
@ -443,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "", 0, "")
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
q := r.client.Account.Query()
|
||||
|
||||
if platform != "" {
|
||||
@ -479,6 +490,20 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
} else if groupID > 0 {
|
||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||
}
|
||||
if privacyMode != "" {
|
||||
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||
path := sqljson.Path("privacy_mode")
|
||||
switch privacyMode {
|
||||
case service.AccountPrivacyModeUnsetFilter:
|
||||
s.Where(entsql.Or(
|
||||
entsql.Not(sqljson.HasKey(dbaccount.FieldExtra, path)),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, "", path),
|
||||
))
|
||||
default:
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, privacyMode, path))
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
|
||||
@ -208,15 +208,16 @@ func (s *AccountRepoSuite) TestList() {
|
||||
|
||||
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(client *dbent.Client)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
groupID int64
|
||||
wantCount int
|
||||
validate func(accounts []service.Account)
|
||||
name string
|
||||
setup func(client *dbent.Client)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
groupID int64
|
||||
privacyMode string
|
||||
wantCount int
|
||||
validate func(accounts []service.Account)
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
@ -281,6 +282,32 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
s.Require().Empty(accounts[0].GroupIDs)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_privacy_mode",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-ok", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-fail", Extra: map[string]any{"privacy_mode": service.PrivacyModeFailed}})
|
||||
},
|
||||
privacyMode: service.PrivacyModeTrainingOff,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal("privacy-ok", accounts[0].Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_privacy_mode_unset",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-unset", Extra: nil})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-empty", Extra: map[string]any{"privacy_mode": ""}})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-set", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
|
||||
},
|
||||
privacyMode: service.AccountPrivacyModeUnsetFilter,
|
||||
wantCount: 2,
|
||||
validate: func(accounts []service.Account) {
|
||||
names := []string{accounts[0].Name, accounts[1].Name}
|
||||
s.ElementsMatch([]string{"privacy-unset", "privacy-empty"}, names)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -293,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
@ -360,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
|
||||
@ -29,6 +29,11 @@ INSERT INTO ops_error_logs (
|
||||
model,
|
||||
request_path,
|
||||
stream,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
request_type,
|
||||
user_agent,
|
||||
error_phase,
|
||||
error_type,
|
||||
@ -57,7 +62,7 @@ INSERT INTO ops_error_logs (
|
||||
retry_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43
|
||||
)`
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
|
||||
opsNullString(input.Model),
|
||||
opsNullString(input.RequestPath),
|
||||
input.Stream,
|
||||
opsNullString(input.InboundEndpoint),
|
||||
opsNullString(input.UpstreamEndpoint),
|
||||
opsNullString(input.RequestedModel),
|
||||
opsNullString(input.UpstreamModel),
|
||||
opsNullInt16(input.RequestType),
|
||||
opsNullString(input.UserAgent),
|
||||
input.ErrorPhase,
|
||||
input.ErrorType,
|
||||
@ -231,7 +241,12 @@ SELECT
|
||||
COALESCE(g.name, ''),
|
||||
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
|
||||
COALESCE(e.request_path, ''),
|
||||
e.stream
|
||||
e.stream,
|
||||
COALESCE(e.inbound_endpoint, ''),
|
||||
COALESCE(e.upstream_endpoint, ''),
|
||||
COALESCE(e.requested_model, ''),
|
||||
COALESCE(e.upstream_model, ''),
|
||||
e.request_type
|
||||
FROM ops_error_logs e
|
||||
LEFT JOIN accounts a ON e.account_id = a.id
|
||||
LEFT JOIN groups g ON e.group_id = g.id
|
||||
@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||
var resolvedBy sql.NullInt64
|
||||
var resolvedByName string
|
||||
var resolvedRetryID sql.NullInt64
|
||||
var requestType sql.NullInt64
|
||||
if err := rows.Scan(
|
||||
&item.ID,
|
||||
&item.CreatedAt,
|
||||
@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||
&clientIP,
|
||||
&item.RequestPath,
|
||||
&item.Stream,
|
||||
&item.InboundEndpoint,
|
||||
&item.UpstreamEndpoint,
|
||||
&item.RequestedModel,
|
||||
&item.UpstreamModel,
|
||||
&requestType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||
item.GroupID = &v
|
||||
}
|
||||
item.GroupName = groupName
|
||||
if requestType.Valid {
|
||||
v := int16(requestType.Int64)
|
||||
item.RequestType = &v
|
||||
}
|
||||
out = append(out, &item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@ -393,6 +418,11 @@ SELECT
|
||||
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
|
||||
COALESCE(e.request_path, ''),
|
||||
e.stream,
|
||||
COALESCE(e.inbound_endpoint, ''),
|
||||
COALESCE(e.upstream_endpoint, ''),
|
||||
COALESCE(e.requested_model, ''),
|
||||
COALESCE(e.upstream_model, ''),
|
||||
e.request_type,
|
||||
COALESCE(e.user_agent, ''),
|
||||
e.auth_latency_ms,
|
||||
e.routing_latency_ms,
|
||||
@ -427,6 +457,7 @@ LIMIT 1`
|
||||
var responseLatency sql.NullInt64
|
||||
var ttft sql.NullInt64
|
||||
var requestBodyBytes sql.NullInt64
|
||||
var requestType sql.NullInt64
|
||||
|
||||
err := r.db.QueryRowContext(ctx, q, id).Scan(
|
||||
&out.ID,
|
||||
@ -464,6 +495,11 @@ LIMIT 1`
|
||||
&clientIP,
|
||||
&out.RequestPath,
|
||||
&out.Stream,
|
||||
&out.InboundEndpoint,
|
||||
&out.UpstreamEndpoint,
|
||||
&out.RequestedModel,
|
||||
&out.UpstreamModel,
|
||||
&requestType,
|
||||
&out.UserAgent,
|
||||
&authLatency,
|
||||
&routingLatency,
|
||||
@ -540,6 +576,10 @@ LIMIT 1`
|
||||
v := int(requestBodyBytes.Int64)
|
||||
out.RequestBodyBytes = &v
|
||||
}
|
||||
if requestType.Valid {
|
||||
v := int16(requestType.Int64)
|
||||
out.RequestType = &v
|
||||
}
|
||||
|
||||
// Normalize request_body to empty string when stored as JSON null.
|
||||
out.RequestBody = strings.TrimSpace(out.RequestBody)
|
||||
@ -1479,3 +1519,10 @@ func opsNullInt(v any) any {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
}
|
||||
|
||||
func opsNullInt16(v *int16) any {
|
||||
if v == nil {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||||
}
|
||||
|
||||
@ -990,7 +990,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@ -69,12 +69,30 @@ func RegisterGatewayRoutes(
|
||||
})
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
// OpenAI Responses API: auto-route based on group platform
|
||||
gateway.POST("/responses", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Responses(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Responses(c)
|
||||
})
|
||||
gateway.POST("/responses/*subpath", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Responses(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Responses(c)
|
||||
})
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API
|
||||
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||
// OpenAI Chat Completions API: auto-route based on group platform
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.ChatCompletions(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.ChatCompletions(c)
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
@ -92,12 +110,25 @@ func RegisterGatewayRoutes(
|
||||
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
// OpenAI Responses API(不带v1前缀的别名)— auto-route based on group platform
|
||||
responsesHandler := func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Responses(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Responses(c)
|
||||
}
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.ChatCompletions(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.ChatCompletions(c)
|
||||
})
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
|
||||
30
backend/internal/service/account_credentials_persistence.go
Normal file
30
backend/internal/service/account_credentials_persistence.go
Normal file
@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
type accountCredentialsUpdater interface {
|
||||
UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error
|
||||
}
|
||||
|
||||
func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error {
|
||||
if repo == nil || account == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
account.Credentials = cloneCredentials(credentials)
|
||||
if updater, ok := any(repo).(accountCredentialsUpdater); ok {
|
||||
return updater.UpdateCredentials(ctx, account.ID, account.Credentials)
|
||||
}
|
||||
return repo.Update(ctx, account)
|
||||
}
|
||||
|
||||
func cloneCredentials(in map[string]any) map[string]any {
|
||||
if in == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -15,6 +15,7 @@ var (
|
||||
)
|
||||
|
||||
const AccountListGroupUngrouped int64 = -1
|
||||
const AccountPrivacyModeUnsetFilter = "__unset__"
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *Account) error
|
||||
@ -37,7 +38,7 @@ type AccountRepository interface {
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
|
||||
@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ type AdminService interface {
|
||||
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||
@ -1451,9 +1451,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct {
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersPrivacy string
|
||||
listWithFiltersAccounts []Account
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersType = accountType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
s.listWithFiltersPrivacy = privacyMode
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
|
||||
t.Run("privacy_mode 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
listWithFiltersAccounts: []Account{{ID: 2, Name: "acc2"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
|
||||
require.Equal(t, PrivacyModeCFBlocked, repo.listWithFiltersPrivacy)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
|
||||
@ -643,6 +643,7 @@ urlFallbackLoop:
|
||||
AccountID: p.account.ID,
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
@ -720,6 +721,7 @@ urlFallbackLoop:
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
@ -754,6 +756,7 @@ urlFallbackLoop:
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
|
||||
@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
p.markBackfillAttempted(account.ID)
|
||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||
account.Credentials["project_id"] = projectID
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil {
|
||||
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
||||
"account_id", account.ID,
|
||||
"error", updateErr,
|
||||
|
||||
@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if targetType == AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
|
||||
}
|
||||
}
|
||||
item.Action = "created"
|
||||
@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if targetType == AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
|
||||
}
|
||||
}
|
||||
|
||||
@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
continue
|
||||
}
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"net/smtp"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[SettingKeySMTPHost]
|
||||
host := strings.TrimSpace(settings[SettingKeySMTPHost])
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
return &SMTPConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[SettingKeySMTPUsername],
|
||||
Password: settings[SettingKeySMTPPassword],
|
||||
From: settings[SettingKeySMTPFrom],
|
||||
FromName: settings[SettingKeySMTPFromName],
|
||||
Username: strings.TrimSpace(settings[SettingKeySMTPUsername]),
|
||||
Password: strings.TrimSpace(settings[SettingKeySMTPPassword]),
|
||||
From: strings.TrimSpace(settings[SettingKeySMTPFrom]),
|
||||
FromName: strings.TrimSpace(settings[SettingKeySMTPFromName]),
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
485
backend/internal/service/gateway_forward_as_chat_completions.go
Normal file
485
backend/internal/service/gateway_forward_as_chat_completions.go
Normal file
@ -0,0 +1,485 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ForwardAsChatCompletions accepts an OpenAI Chat Completions API request body,
|
||||
// converts it to Anthropic Messages format (chained via Responses format),
|
||||
// forwards to the Anthropic upstream, and converts the response back to Chat
|
||||
// Completions format. This enables Chat Completions clients to access Anthropic
|
||||
// models through Anthropic platform groups.
|
||||
func (s *GatewayService) ForwardAsChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
parsed *ParsedRequest,
|
||||
) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse Chat Completions request
|
||||
var ccReq apicompat.ChatCompletionsRequest
|
||||
if err := json.Unmarshal(body, &ccReq); err != nil {
|
||||
return nil, fmt.Errorf("parse chat completions request: %w", err)
|
||||
}
|
||||
originalModel := ccReq.Model
|
||||
clientStream := ccReq.Stream
|
||||
includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
|
||||
|
||||
// 2. Convert CC → Responses → Anthropic (chained conversion)
|
||||
responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||
}
|
||||
|
||||
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
|
||||
}
|
||||
|
||||
// 3. Force upstream streaming
|
||||
anthropicReq.Stream = true
|
||||
reqStream := true
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
}
|
||||
anthropicReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("gateway forward_as_chat_completions: model mapping applied",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("mapped_model", mappedModel),
|
||||
zap.Bool("client_stream", clientStream),
|
||||
)
|
||||
|
||||
// 5. Marshal Anthropic request body
|
||||
anthropicBody, err := json.Marshal(anthropicReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal anthropic request: %w", err)
|
||||
}
|
||||
|
||||
// 6. Apply Claude Code mimicry for OAuth accounts
|
||||
isClaudeCode := false // CC API is never Claude Code
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
|
||||
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 12. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
|
||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// 13. Extract reasoning effort from CC request body
|
||||
reasoningEffort := extractCCReasoningEffortFromBody(body)
|
||||
|
||||
// 14. Handle normal response
|
||||
// Read Anthropic SSE → convert to Responses events → convert to CC format
|
||||
var result *ForwardResult
|
||||
var handleErr error
|
||||
if clientStream {
|
||||
result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage)
|
||||
} else {
|
||||
result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
|
||||
}
|
||||
|
||||
return result, handleErr
|
||||
}
|
||||
|
||||
// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions
|
||||
// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort)
|
||||
// formats used by OpenAI-compatible clients.
|
||||
func extractCCReasoningEffortFromBody(body []byte) *string {
|
||||
raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if raw == "" {
|
||||
raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
|
||||
}
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
normalized := normalizeOpenAIReasoningEffort(raw)
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
return &normalized
|
||||
}
|
||||
|
||||
// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full
|
||||
// response, then converts Anthropic → Responses → Chat Completions.
|
||||
func (s *GatewayService) handleCCBufferedFromAnthropic(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
reasoningEffort *string,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
var finalResp *apicompat.AnthropicResponse
|
||||
var usage ClaudeUsage
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
dataLine := scanner.Text()
|
||||
if !strings.HasPrefix(dataLine, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := dataLine[6:]
|
||||
|
||||
var event apicompat.AnthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// message_start carries the initial response structure and cache usage
|
||||
if event.Type == "message_start" && event.Message != nil {
|
||||
finalResp = event.Message
|
||||
mergeAnthropicUsage(&usage, event.Message.Usage)
|
||||
}
|
||||
|
||||
// message_delta carries final usage and stop_reason
|
||||
if event.Type == "message_delta" {
|
||||
if event.Usage != nil {
|
||||
mergeAnthropicUsage(&usage, *event.Usage)
|
||||
}
|
||||
if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
|
||||
finalResp.StopReason = event.Delta.StopReason
|
||||
}
|
||||
}
|
||||
if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
|
||||
finalResp.Content = append(finalResp.Content, *event.ContentBlock)
|
||||
}
|
||||
if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
|
||||
idx := *event.Index
|
||||
if idx < len(finalResp.Content) {
|
||||
switch event.Delta.Type {
|
||||
case "text_delta":
|
||||
finalResp.Content[idx].Text += event.Delta.Text
|
||||
case "thinking_delta":
|
||||
finalResp.Content[idx].Thinking += event.Delta.Thinking
|
||||
case "input_json_delta":
|
||||
finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("forward_as_cc buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if finalResp == nil {
|
||||
writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
|
||||
return nil, fmt.Errorf("upstream stream ended without response")
|
||||
}
|
||||
|
||||
// Update usage from accumulated delta
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
finalResp.Usage = apicompat.AnthropicUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Chain: Anthropic → Responses → Chat Completions
|
||||
responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
|
||||
ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.JSON(http.StatusOK, ccResp)
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each
|
||||
// to Responses events, then to Chat Completions chunks, and writes them.
|
||||
func (s *GatewayService) handleCCStreamingFromAnthropic(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
reasoningEffort *string,
|
||||
startTime time.Time,
|
||||
includeUsage bool,
|
||||
) (*ForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
// Use Anthropic→Responses state machine, then convert Responses→CC
|
||||
anthState := apicompat.NewAnthropicEventToResponsesState()
|
||||
anthState.Model = originalModel
|
||||
ccState := apicompat.NewResponsesEventToChatState()
|
||||
ccState.Model = originalModel
|
||||
ccState.IncludeUsage = includeUsage
|
||||
|
||||
var usage ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
resultWithUsage := func() *ForwardResult {
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}
|
||||
}
|
||||
|
||||
writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
|
||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
return true // client disconnected
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// Extract usage from message_delta
|
||||
if event.Type == "message_delta" && event.Usage != nil {
|
||||
mergeAnthropicUsage(&usage, *event.Usage)
|
||||
}
|
||||
// Also capture usage from message_start (carries cache fields)
|
||||
if event.Type == "message_start" && event.Message != nil {
|
||||
mergeAnthropicUsage(&usage, event.Message.Usage)
|
||||
}
|
||||
|
||||
// Chain: Anthropic event → Responses events → CC chunks
|
||||
responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState)
|
||||
for _, resEvt := range responsesEvents {
|
||||
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
|
||||
for _, chunk := range ccChunks {
|
||||
if disconnected := writeChunk(chunk); disconnected {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return false
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
dataLine := scanner.Text()
|
||||
if !strings.HasPrefix(dataLine, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := dataLine[6:]
|
||||
|
||||
var event apicompat.AnthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if processAnthropicEvent(&event) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("forward_as_cc stream: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize both state machines
|
||||
finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState)
|
||||
for _, resEvt := range finalResEvents {
|
||||
ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
|
||||
for _, chunk := range ccChunks {
|
||||
writeChunk(chunk) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState)
|
||||
for _, chunk := range finalCCChunks {
|
||||
writeChunk(chunk) //nolint:errcheck
|
||||
}
|
||||
|
||||
// Write [DONE] marker
|
||||
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
||||
c.Writer.Flush()
|
||||
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
// writeGatewayCCError writes an error in OpenAI Chat Completions format for
|
||||
// the Anthropic-upstream CC forwarding path.
|
||||
func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,109 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractCCReasoningEffortFromBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nested reasoning.effort", func(t *testing.T) {
|
||||
got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`))
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "high", *got)
|
||||
})
|
||||
|
||||
t.Run("flat reasoning_effort", func(t *testing.T) {
|
||||
got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`))
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "xhigh", *got)
|
||||
})
|
||||
|
||||
t.Run("missing effort", func(t *testing.T) {
|
||||
require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
reasoningEffort := "high"
|
||||
resp := &http.Response{
|
||||
Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`event: message_start`,
|
||||
`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
|
||||
``,
|
||||
`event: content_block_start`,
|
||||
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
|
||||
``,
|
||||
`event: message_delta`,
|
||||
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
|
||||
``,
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
svc := &GatewayService{}
|
||||
result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 12, result.Usage.InputTokens)
|
||||
require.Equal(t, 7, result.Usage.OutputTokens)
|
||||
require.Equal(t, 9, result.Usage.CacheReadInputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "high", *result.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
reasoningEffort := "medium"
|
||||
resp := &http.Response{
|
||||
Header: http.Header{"x-request-id": []string{"rid_cc_stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`event: message_start`,
|
||||
`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
|
||||
``,
|
||||
`event: content_block_start`,
|
||||
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
|
||||
``,
|
||||
`event: message_delta`,
|
||||
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
|
||||
``,
|
||||
`event: message_stop`,
|
||||
`data: {"type":"message_stop"}`,
|
||||
``,
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
svc := &GatewayService{}
|
||||
result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 20, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 11, result.Usage.CacheReadInputTokens)
|
||||
require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "medium", *result.ReasoningEffort)
|
||||
require.Contains(t, rec.Body.String(), `[DONE]`)
|
||||
}
|
||||
518
backend/internal/service/gateway_forward_as_responses.go
Normal file
518
backend/internal/service/gateway_forward_as_responses.go
Normal file
@ -0,0 +1,518 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ForwardAsResponses accepts an OpenAI Responses API request body, converts it
|
||||
// to Anthropic Messages format, forwards to the Anthropic upstream, and converts
|
||||
// the response back to Responses format. This enables OpenAI Responses API
|
||||
// clients to access Anthropic models through Anthropic platform groups.
|
||||
//
|
||||
// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic
|
||||
// but in reverse direction: Responses → Anthropic upstream → Responses.
|
||||
func (s *GatewayService) ForwardAsResponses(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
parsed *ParsedRequest,
|
||||
) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Parse Responses request
|
||||
var responsesReq apicompat.ResponsesRequest
|
||||
if err := json.Unmarshal(body, &responsesReq); err != nil {
|
||||
return nil, fmt.Errorf("parse responses request: %w", err)
|
||||
}
|
||||
originalModel := responsesReq.Model
|
||||
clientStream := responsesReq.Stream
|
||||
|
||||
// 2. Convert Responses → Anthropic
|
||||
anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert responses to anthropic: %w", err)
|
||||
}
|
||||
|
||||
// 3. Force upstream streaming (Anthropic works best with streaming)
|
||||
anthropicReq.Stream = true
|
||||
reqStream := true
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
}
|
||||
anthropicReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("gateway forward_as_responses: model mapping applied",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("mapped_model", mappedModel),
|
||||
zap.Bool("client_stream", clientStream),
|
||||
)
|
||||
|
||||
// 5. Marshal Anthropic request body
|
||||
anthropicBody, err := json.Marshal(anthropicReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal anthropic request: %w", err)
|
||||
}
|
||||
|
||||
// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
|
||||
isClaudeCode := false // Responses API is never Claude Code
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(anthropicReq.System) {
|
||||
anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Enforce cache_control block limit
|
||||
anthropicBody = enforceCacheControlLimit(anthropicBody)
|
||||
|
||||
// 8. Get access token
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// 9. Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 10. Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 11. Send request
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 12. Handle error response with failover
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
// Non-failover error: return Responses-formatted error to client
|
||||
writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
|
||||
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// 13. Handle normal response (convert Anthropic → Responses)
|
||||
var result *ForwardResult
|
||||
var handleErr error
|
||||
if clientStream {
|
||||
result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
|
||||
} else {
|
||||
result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
|
||||
}
|
||||
|
||||
return result, handleErr
|
||||
}
|
||||
|
||||
// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort
|
||||
// and normalizes it for usage logging.
|
||||
func ExtractResponsesReasoningEffortFromBody(body []byte) *string {
|
||||
raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
normalized := normalizeOpenAIReasoningEffort(raw)
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
return &normalized
|
||||
}
|
||||
|
||||
func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) {
|
||||
if dst == nil {
|
||||
return
|
||||
}
|
||||
if src.InputTokens > 0 {
|
||||
dst.InputTokens = src.InputTokens
|
||||
}
|
||||
if src.OutputTokens > 0 {
|
||||
dst.OutputTokens = src.OutputTokens
|
||||
}
|
||||
if src.CacheReadInputTokens > 0 {
|
||||
dst.CacheReadInputTokens = src.CacheReadInputTokens
|
||||
}
|
||||
if src.CacheCreationInputTokens > 0 {
|
||||
dst.CacheCreationInputTokens = src.CacheCreationInputTokens
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from
|
||||
// the upstream streaming response, assembles them into a complete Anthropic
|
||||
// response, converts to Responses API JSON format, and writes it to the client.
|
||||
func (s *GatewayService) handleResponsesBufferedStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
reasoningEffort *string,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
// Accumulate the final Anthropic response from streaming events
|
||||
var finalResp *apicompat.AnthropicResponse
|
||||
var usage ClaudeUsage
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
eventType := strings.TrimPrefix(line, "event: ")
|
||||
|
||||
// Read the data line
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
dataLine := scanner.Text()
|
||||
if !strings.HasPrefix(dataLine, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := dataLine[6:]
|
||||
|
||||
var event apicompat.AnthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("forward_as_responses buffered: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("event_type", eventType),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// message_start carries the initial response structure
|
||||
if event.Type == "message_start" && event.Message != nil {
|
||||
finalResp = event.Message
|
||||
mergeAnthropicUsage(&usage, event.Message.Usage)
|
||||
}
|
||||
|
||||
// message_delta carries final usage and stop_reason
|
||||
if event.Type == "message_delta" {
|
||||
if event.Usage != nil {
|
||||
mergeAnthropicUsage(&usage, *event.Usage)
|
||||
}
|
||||
if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
|
||||
finalResp.StopReason = event.Delta.StopReason
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate content blocks
|
||||
if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
|
||||
finalResp.Content = append(finalResp.Content, *event.ContentBlock)
|
||||
}
|
||||
if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
|
||||
idx := *event.Index
|
||||
if idx < len(finalResp.Content) {
|
||||
switch event.Delta.Type {
|
||||
case "text_delta":
|
||||
finalResp.Content[idx].Text += event.Delta.Text
|
||||
case "thinking_delta":
|
||||
finalResp.Content[idx].Thinking += event.Delta.Thinking
|
||||
case "input_json_delta":
|
||||
finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("forward_as_responses buffered: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if finalResp == nil {
|
||||
writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
|
||||
return nil, fmt.Errorf("upstream stream ended without response")
|
||||
}
|
||||
|
||||
// Update usage from accumulated delta
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
finalResp.Usage = apicompat.AnthropicUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to Responses format
|
||||
responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
|
||||
responsesResp.Model = originalModel // Use original model name
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.JSON(http.StatusOK, responsesResp)
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleResponsesStreamingResponse reads Anthropic SSE events from upstream,
|
||||
// converts each to Responses SSE events, and writes them to the client.
|
||||
func (s *GatewayService) handleResponsesStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
reasoningEffort *string,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
state := apicompat.NewAnthropicEventToResponsesState()
|
||||
state.Model = originalModel
|
||||
var usage ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
firstChunk := true
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
resultWithUsage := func() *ForwardResult {
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}
|
||||
}
|
||||
|
||||
// processEvent handles a single parsed Anthropic SSE event.
|
||||
processEvent := func(event *apicompat.AnthropicStreamEvent) bool {
|
||||
if firstChunk {
|
||||
firstChunk = false
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// Extract usage from message_delta
|
||||
if event.Type == "message_delta" && event.Usage != nil {
|
||||
mergeAnthropicUsage(&usage, *event.Usage)
|
||||
}
|
||||
// Also capture usage from message_start
|
||||
if event.Type == "message_start" && event.Message != nil {
|
||||
mergeAnthropicUsage(&usage, event.Message.Usage)
|
||||
}
|
||||
|
||||
// Convert to Responses events
|
||||
events := apicompat.AnthropicEventToResponsesEvents(event, state)
|
||||
for _, evt := range events {
|
||||
sse, err := apicompat.ResponsesEventToSSE(evt)
|
||||
if err != nil {
|
||||
logger.L().Warn("forward_as_responses stream: failed to marshal event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
logger.L().Info("forward_as_responses stream: client disconnected",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return true // client disconnected
|
||||
}
|
||||
}
|
||||
if len(events) > 0 {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
finalizeStream := func() (*ForwardResult, error) {
|
||||
if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 {
|
||||
for _, evt := range finalEvents {
|
||||
sse, err := apicompat.ResponsesEventToSSE(evt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
|
||||
// Read Anthropic SSE events
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
eventType := strings.TrimPrefix(line, "event: ")
|
||||
|
||||
// Read data line
|
||||
if !scanner.Scan() {
|
||||
break
|
||||
}
|
||||
dataLine := scanner.Text()
|
||||
if !strings.HasPrefix(dataLine, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := dataLine[6:]
|
||||
|
||||
var event apicompat.AnthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
logger.L().Warn("forward_as_responses stream: failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("event_type", eventType),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if processEvent(&event) {
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("forward_as_responses stream: read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return finalizeStream()
|
||||
}
|
||||
|
||||
// appendRawJSON appends a JSON fragment string to existing raw JSON.
|
||||
func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage {
|
||||
if len(existing) == 0 {
|
||||
return json.RawMessage(fragment)
|
||||
}
|
||||
return json.RawMessage(string(existing) + fragment)
|
||||
}
|
||||
|
||||
// writeResponsesError writes an error response in OpenAI Responses API format.
|
||||
func writeResponsesError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// mapUpstreamStatusCode maps upstream HTTP status codes to appropriate client-facing codes.
|
||||
func mapUpstreamStatusCode(code int) int {
|
||||
if code >= 500 {
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
return code
|
||||
}
|
||||
@ -0,0 +1,94 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractResponsesReasoningEffortFromBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`))
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "high", *got)
|
||||
|
||||
require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`)))
|
||||
}
|
||||
|
||||
func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
resp := &http.Response{
|
||||
Header: http.Header{"x-request-id": []string{"rid_buffered"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`event: message_start`,
|
||||
`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
|
||||
``,
|
||||
`event: content_block_start`,
|
||||
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
|
||||
``,
|
||||
`event: message_delta`,
|
||||
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
|
||||
``,
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
svc := &GatewayService{}
|
||||
result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 12, result.Usage.InputTokens)
|
||||
require.Equal(t, 7, result.Usage.OutputTokens)
|
||||
require.Equal(t, 9, result.Usage.CacheReadInputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
|
||||
require.Contains(t, rec.Body.String(), `"cached_tokens":9`)
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
resp := &http.Response{
|
||||
Header: http.Header{"x-request-id": []string{"rid_stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`event: message_start`,
|
||||
`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
|
||||
``,
|
||||
`event: content_block_start`,
|
||||
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
|
||||
``,
|
||||
`event: message_delta`,
|
||||
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
|
||||
``,
|
||||
`event: message_stop`,
|
||||
`data: {"type":"message_stop"}`,
|
||||
``,
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
svc := &GatewayService{}
|
||||
result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 20, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 11, result.Usage.CacheReadInputTokens)
|
||||
require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
|
||||
require.Contains(t, rec.Body.String(), `response.completed`)
|
||||
}
|
||||
@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
|
||||
@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
@ -34,6 +36,9 @@ var (
|
||||
patternEmptyTextSpaced = []byte(`"text": ""`)
|
||||
patternEmptyTextSp1 = []byte(`"text" : ""`)
|
||||
patternEmptyTextSp2 = []byte(`"text" :""`)
|
||||
|
||||
sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`)
|
||||
sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`)
|
||||
)
|
||||
|
||||
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||
@ -75,6 +80,49 @@ type ParsedRequest struct {
|
||||
OnUpstreamAccepted func()
|
||||
}
|
||||
|
||||
// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing.
|
||||
// It preserves the set of product names from Product/Version tokens while
|
||||
// discarding version-only changes and incidental comments.
|
||||
func NormalizeSessionUserAgent(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1)
|
||||
if len(matches) == 0 {
|
||||
return normalizeSessionUserAgentFallback(raw)
|
||||
}
|
||||
|
||||
products := make([]string, 0, len(matches))
|
||||
seen := make(map[string]struct{}, len(matches))
|
||||
for _, match := range matches {
|
||||
if len(match) < 2 {
|
||||
continue
|
||||
}
|
||||
product := strings.ToLower(strings.TrimSpace(match[1]))
|
||||
if product == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[product]; exists {
|
||||
continue
|
||||
}
|
||||
seen[product] = struct{}{}
|
||||
products = append(products, product)
|
||||
}
|
||||
if len(products) == 0 {
|
||||
return normalizeSessionUserAgentFallback(raw)
|
||||
}
|
||||
sort.Strings(products)
|
||||
return strings.Join(products, "+")
|
||||
}
|
||||
|
||||
func normalizeSessionUserAgentFallback(raw string) string {
|
||||
normalized := strings.ToLower(strings.Join(strings.Fields(raw), " "))
|
||||
normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "")
|
||||
return strings.Join(strings.Fields(normalized), " ")
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||
// 不同协议使用不同的 system/messages 字段名。
|
||||
@ -205,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte {
|
||||
return []byte(r.Raw)
|
||||
}
|
||||
|
||||
// stripEmptyTextBlocksFromSlice removes empty text blocks from a content slice (including nested tool_result content).
|
||||
// Returns (cleaned slice, true) if any blocks were removed, or (original, false) if unchanged.
|
||||
func stripEmptyTextBlocksFromSlice(blocks []any) ([]any, bool) {
|
||||
var result []any
|
||||
changed := false
|
||||
for i, block := range blocks {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if result != nil {
|
||||
result = append(result, block)
|
||||
}
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
// Strip empty text blocks
|
||||
if blockType == "text" {
|
||||
if txt, _ := blockMap["text"].(string); txt == "" {
|
||||
if result == nil {
|
||||
result = make([]any, 0, len(blocks))
|
||||
result = append(result, blocks[:i]...)
|
||||
}
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into tool_result nested content
|
||||
if blockType == "tool_result" {
|
||||
if nestedContent, ok := blockMap["content"].([]any); ok {
|
||||
if cleaned, nestedChanged := stripEmptyTextBlocksFromSlice(nestedContent); nestedChanged {
|
||||
if result == nil {
|
||||
result = make([]any, 0, len(blocks))
|
||||
result = append(result, blocks[:i]...)
|
||||
}
|
||||
changed = true
|
||||
blockCopy := make(map[string]any, len(blockMap))
|
||||
for k, v := range blockMap {
|
||||
blockCopy[k] = v
|
||||
}
|
||||
blockCopy["content"] = cleaned
|
||||
result = append(result, blockCopy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
result = append(result, block)
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return blocks, false
|
||||
}
|
||||
return result, true
|
||||
}
|
||||
|
||||
// StripEmptyTextBlocks removes empty text blocks from the request body (including nested tool_result content).
|
||||
// This is a lightweight pre-filter for the initial request path to prevent upstream 400 errors.
|
||||
// Returns the original body unchanged if no empty text blocks are found.
|
||||
func StripEmptyTextBlocks(body []byte) []byte {
|
||||
// Fast path: check if body contains empty text patterns
|
||||
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
|
||||
bytes.Contains(body, patternEmptyTextSpaced) ||
|
||||
bytes.Contains(body, patternEmptyTextSp1) ||
|
||||
bytes.Contains(body, patternEmptyTextSp2)
|
||||
if !hasEmptyTextBlock {
|
||||
return body
|
||||
}
|
||||
|
||||
jsonStr := *(*string)(unsafe.Pointer(&body))
|
||||
msgsRes := gjson.Get(jsonStr, "messages")
|
||||
if !msgsRes.Exists() || !msgsRes.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
var messages []any
|
||||
if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if cleaned, changed := stripEmptyTextBlocksFromSlice(content); changed {
|
||||
modified = true
|
||||
msgMap["content"] = cleaned
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body
|
||||
}
|
||||
|
||||
msgsBytes, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
out, err := sjson.SetRawBytes(body, "messages", msgsBytes)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// FilterThinkingBlocks removes thinking blocks from request body
|
||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||
// This prevents 400 errors from invalid thinking block signatures
|
||||
@ -378,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively strip empty text blocks from tool_result nested content.
|
||||
if blockType == "tool_result" {
|
||||
if nestedContent, ok := blockMap["content"].([]any); ok {
|
||||
if cleaned, changed := stripEmptyTextBlocksFromSlice(nestedContent); changed {
|
||||
modifiedThisMsg = true
|
||||
ensureNewContent(bi)
|
||||
blockCopy := make(map[string]any, len(blockMap))
|
||||
for k, v := range blockMap {
|
||||
blockCopy[k] = v
|
||||
}
|
||||
blockCopy["content"] = cleaned
|
||||
newContent = append(newContent, blockCopy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if newContent != nil {
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
|
||||
require.NotEmpty(t, block1["text"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult(t *testing.T) {
|
||||
// Empty text blocks nested inside tool_result content should also be stripped
|
||||
input := []byte(`{
|
||||
"messages":[
|
||||
{"role":"user","content":[
|
||||
{"type":"tool_result","tool_use_id":"t1","content":[
|
||||
{"type":"text","text":"valid result"},
|
||||
{"type":"text","text":""}
|
||||
]}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs := req["messages"].([]any)
|
||||
msg0 := msgs[0].(map[string]any)
|
||||
content0 := msg0["content"].([]any)
|
||||
require.Len(t, content0, 1)
|
||||
toolResult := content0[0].(map[string]any)
|
||||
require.Equal(t, "tool_result", toolResult["type"])
|
||||
nestedContent := toolResult["content"].([]any)
|
||||
require.Len(t, nestedContent, 1)
|
||||
require.Equal(t, "valid result", nestedContent[0].(map[string]any)["text"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice(t *testing.T) {
|
||||
// If all nested content blocks in tool_result are empty text, content becomes empty slice
|
||||
input := []byte(`{
|
||||
"messages":[
|
||||
{"role":"user","content":[
|
||||
{"type":"tool_result","tool_use_id":"t1","content":[
|
||||
{"type":"text","text":""}
|
||||
]},
|
||||
{"type":"text","text":"hello"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs := req["messages"].([]any)
|
||||
msg0 := msgs[0].(map[string]any)
|
||||
content0 := msg0["content"].([]any)
|
||||
require.Len(t, content0, 2)
|
||||
toolResult := content0[0].(map[string]any)
|
||||
nestedContent := toolResult["content"].([]any)
|
||||
require.Len(t, nestedContent, 0)
|
||||
}
|
||||
|
||||
func TestStripEmptyTextBlocks(t *testing.T) {
|
||||
t.Run("strips top-level empty text", func(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`)
|
||||
out := StripEmptyTextBlocks(input)
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs := req["messages"].([]any)
|
||||
content := msgs[0].(map[string]any)["content"].([]any)
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "hello", content[0].(map[string]any)["text"])
|
||||
})
|
||||
|
||||
t.Run("strips nested empty text in tool_result", func(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`)
|
||||
out := StripEmptyTextBlocks(input)
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs := req["messages"].([]any)
|
||||
content := msgs[0].(map[string]any)["content"].([]any)
|
||||
toolResult := content[0].(map[string]any)
|
||||
nestedContent := toolResult["content"].([]any)
|
||||
require.Len(t, nestedContent, 1)
|
||||
require.Equal(t, "ok", nestedContent[0].(map[string]any)["text"])
|
||||
})
|
||||
|
||||
t.Run("no-op when no empty text", func(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
out := StripEmptyTextBlocks(input)
|
||||
require.Equal(t, input, out)
|
||||
})
|
||||
|
||||
t.Run("preserves non-map blocks in content", func(t *testing.T) {
|
||||
// tool_result content can be a string; non-map blocks should pass through unchanged
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`)
|
||||
out := StripEmptyTextBlocks(input)
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs := req["messages"].([]any)
|
||||
content := msgs[0].(map[string]any)["content"].([]any)
|
||||
require.Len(t, content, 1)
|
||||
toolResult := content[0].(map[string]any)
|
||||
require.Equal(t, "tool_result", toolResult["type"])
|
||||
require.Equal(t, "string content", toolResult["content"])
|
||||
})
|
||||
|
||||
t.Run("handles deeply nested tool_result", func(t *testing.T) {
|
||||
// Recursive: tool_result containing another tool_result with empty text
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`)
|
||||
out := StripEmptyTextBlocks(input)
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs := req["messages"].([]any)
|
||||
content := msgs[0].(map[string]any)["content"].([]any)
|
||||
outer := content[0].(map[string]any)
|
||||
innerContent := outer["content"].([]any)
|
||||
inner := innerContent[0].(map[string]any)
|
||||
deepContent := inner["content"].([]any)
|
||||
require.Len(t, deepContent, 1)
|
||||
require.Equal(t, "deep", deepContent[0].(map[string]any)["text"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
|
||||
// Non-empty text blocks should pass through unchanged
|
||||
input := []byte(`{
|
||||
|
||||
@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
if parsed.SessionContext != nil {
|
||||
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
|
||||
_, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent))
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
||||
_, _ = combined.WriteString("|")
|
||||
@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 调试日志:记录即将转发的账号信息
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
|
||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||
body = StripEmptyTextBlocks(body)
|
||||
|
||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "signature_error",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: retryResp.StatusCode,
|
||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(retryReq.URL.String()),
|
||||
Kind: "signature_retry_thinking",
|
||||
Message: extractUpstreamErrorMessage(retryRespBody),
|
||||
Detail: func() string {
|
||||
@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
UpstreamURL: safeUpstreamURL(retryReq2.URL.String()),
|
||||
Kind: "signature_retry_tools_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
|
||||
})
|
||||
@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "budget_constraint_error",
|
||||
Message: errMsg,
|
||||
Detail: func() string {
|
||||
@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "retry",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
|
||||
if c != nil {
|
||||
c.Set("anthropic_passthrough", true)
|
||||
}
|
||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||
input.Body = StripEmptyTextBlocks(input.Body)
|
||||
|
||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||
setOpsUpstreamRequestBody(c, input.Body)
|
||||
|
||||
@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Passthrough: true,
|
||||
Kind: "retry",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream(
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream(
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Kind: "retry",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
// Pre-filter: strip empty text blocks to prevent upstream 400.
|
||||
body = StripEmptyTextBlocks(body)
|
||||
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||
@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
|
||||
Passthrough: true,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
|
||||
@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
|
||||
@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string {
|
||||
// 返回 16 字符的 Base64 编码的 SHA256 前缀
|
||||
func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string {
|
||||
// 组合所有标识符
|
||||
normalizedUserAgent := NormalizeSessionUserAgent(userAgent)
|
||||
combined := strconv.FormatInt(userID, 10) + ":" +
|
||||
strconv.FormatInt(apiKeyID, 10) + ":" +
|
||||
ip + ":" +
|
||||
userAgent + ":" +
|
||||
normalizedUserAgent + ":" +
|
||||
platform + ":" +
|
||||
model
|
||||
|
||||
|
||||
@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) {
|
||||
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro")
|
||||
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro")
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) {
|
||||
hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro")
|
||||
hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro")
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGeminiSessionValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if tierID != "" {
|
||||
account.Credentials["tier_id"] = tierID
|
||||
}
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
_ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) {
|
||||
require.NotEqual(t, h1, h2, "different User-Agent should produce different hash")
|
||||
}
|
||||
|
||||
func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
base := func(ua string) *ParsedRequest {
|
||||
return &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "test"},
|
||||
},
|
||||
SessionContext: &SessionContext{
|
||||
ClientIP: "1.1.1.1",
|
||||
UserAgent: ua,
|
||||
APIKeyID: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0"))
|
||||
h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1"))
|
||||
require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash")
|
||||
}
|
||||
|
||||
func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
base := func(ua string) *ParsedRequest {
|
||||
return &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "test"},
|
||||
},
|
||||
SessionContext: &SessionContext{
|
||||
ClientIP: "1.1.1.1",
|
||||
UserAgent: ua,
|
||||
APIKeyID: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0"))
|
||||
h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1"))
|
||||
require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash")
|
||||
}
|
||||
|
||||
func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
|
||||
@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
// 5. 设置版本号 + 更新 DB
|
||||
if newCredentials != nil {
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
freshAccount.Credentials = newCredentials
|
||||
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
|
||||
if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
|
||||
slog.Error("oauth_refresh_update_failed",
|
||||
"account_id", freshAccount.ID,
|
||||
"error", updateErr,
|
||||
|
||||
@ -16,10 +16,11 @@ import (
|
||||
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
|
||||
type refreshAPIAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account // returned by GetByID
|
||||
getByIDErr error
|
||||
updateErr error
|
||||
updateCalls int
|
||||
account *Account // returned by GetByID
|
||||
getByIDErr error
|
||||
updateErr error
|
||||
updateCalls int
|
||||
updateCredentialsCalls int
|
||||
}
|
||||
|
||||
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||
@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
|
||||
return r.updateErr
|
||||
}
|
||||
|
||||
func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error {
|
||||
r.updateCalls++
|
||||
r.updateCredentialsCalls++
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
if r.account == nil || r.account.ID != id {
|
||||
r.account = &Account{ID: id}
|
||||
}
|
||||
r.account.Credentials = cloneCredentials(credentials)
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
|
||||
type refreshAPIExecutorStub struct {
|
||||
needsRefresh bool
|
||||
@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
|
||||
require.Equal(t, "new-token", result.NewCredentials["access_token"])
|
||||
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
|
||||
require.Equal(t, 1, repo.updateCalls) // DB updated
|
||||
require.Equal(t, 1, cache.releaseCalls) // lock released
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 1, cache.releaseCalls) // lock released
|
||||
require.Equal(t, 1, executor.refreshCalls)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) {
|
||||
resetAt := time.Now().Add(45 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
RateLimitResetAt: &resetAt,
|
||||
}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
executor := &refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "safe-token"},
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Refreshed)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.NotNil(t, repo.account.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
|
||||
account := &Account{ID: 2, Platform: PlatformAnthropic}
|
||||
repo := &refreshAPIAccountRepo{account: account}
|
||||
@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "invalid_grant")
|
||||
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
|
||||
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
|
||||
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
|
||||
}
|
||||
|
||||
@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) {
|
||||
|
||||
result := MergeCredentials(old, new)
|
||||
|
||||
require.Equal(t, "new-token", result["access_token"]) // overridden
|
||||
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
|
||||
require.Equal(t, "new-token", result["access_token"]) // overridden
|
||||
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
|
||||
}
|
||||
|
||||
// ========== BuildClaudeAccountCredentials tests ==========
|
||||
|
||||
@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
|
||||
if account == nil {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
|
||||
@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
|
||||
require.Equal(t, int64(32002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10103)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{
|
||||
snapshotAccounts: []*Account{staleSticky, staleBackup},
|
||||
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
|
||||
}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
schedulerSnapshot: snapshotService,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(33002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10104)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
snapshotCache := &openAISnapshotCacheStub{
|
||||
snapshotAccounts: []*Account{stalePrimary, staleSecondary},
|
||||
accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary},
|
||||
}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
|
||||
cfg: &config.Config{},
|
||||
schedulerSnapshot: snapshotService,
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(34002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
|
||||
@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
return fresh
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
if s.schedulerSnapshot == nil || s.accountRepo == nil {
|
||||
return account
|
||||
}
|
||||
|
||||
latest, err := s.accountRepo.GetByID(ctx, account.ID)
|
||||
if err != nil || latest == nil {
|
||||
return nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now())
|
||||
if !latest.IsSchedulable() || !latest.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
return latest
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
if s.rateLimitService != nil {
|
||||
// Passthrough mode preserves the raw upstream error response, but runtime
|
||||
// account state still needs to be updated so sticky routing can stop
|
||||
// reusing a freshly rate-limited account.
|
||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
|
||||
@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
|
||||
require.True(t, arr[len(arr)-1].Passthrough)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"x-request-id": []string{"rid-rate-limit"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
repo := &openAIWSRateLimitSignalRepo{}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: rateSvc,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "usage_limit_reached")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -29,9 +29,10 @@ type soraSessionChunk struct {
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
oauthClient OpenAIOAuthClient
|
||||
sessionStore *openai.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
oauthClient OpenAIOAuthClient
|
||||
privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-api(ImpersonateChrome)
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
||||
@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli
|
||||
}
|
||||
}
|
||||
|
||||
// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂,
|
||||
// 用于调用 chatgpt.com/backend-api 获取账号信息(plan_type 等)。
|
||||
func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) {
|
||||
s.privacyClientFactory = factory
|
||||
}
|
||||
|
||||
// OpenAIAuthURLResult contains the authorization URL and session info
|
||||
type OpenAIAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
@ -131,6 +138,7 @@ type OpenAITokenInfo struct {
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
OrganizationID string `json:"organization_id,omitempty"`
|
||||
PlanType string `json:"plan_type,omitempty"`
|
||||
PrivacyMode string `json:"privacy_mode,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
tokenInfo.PlanType = userInfo.PlanType
|
||||
}
|
||||
|
||||
// id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全
|
||||
if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
|
||||
// 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号
|
||||
orgID := tokenInfo.OrganizationID
|
||||
if orgID == "" {
|
||||
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
|
||||
orgID = atClaims.OpenAIAuth.POID
|
||||
}
|
||||
}
|
||||
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
|
||||
if tokenInfo.PlanType == "" && info.PlanType != "" {
|
||||
tokenInfo.PlanType = info.PlanType
|
||||
}
|
||||
if tokenInfo.Email == "" && info.Email != "" {
|
||||
tokenInfo.Email = info.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试设置隐私(关闭训练数据共享),best-effort
|
||||
if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil {
|
||||
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
|
||||
@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto
|
||||
return PrivacyModeTrainingOff
|
||||
}
|
||||
|
||||
// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息
|
||||
type ChatGPTAccountInfo struct {
|
||||
PlanType string
|
||||
Email string
|
||||
}
|
||||
|
||||
const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
|
||||
|
||||
// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.).
|
||||
// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT).
|
||||
// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team).
|
||||
// Returns nil on any failure (best-effort, non-blocking).
|
||||
func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, orgID string) *ChatGPTAccountInfo {
|
||||
if accessToken == "" || clientFactory == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client, err := clientFactory(proxyURL)
|
||||
if err != nil {
|
||||
slog.Debug("chatgpt_account_check_client_error", "error", err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Origin", "https://chatgpt.com").
|
||||
SetHeader("Referer", "https://chatgpt.com/").
|
||||
SetHeader("Accept", "application/json").
|
||||
SetSuccessResult(&result).
|
||||
Get(chatGPTAccountsCheckURL)
|
||||
|
||||
if err != nil {
|
||||
slog.Debug("chatgpt_account_check_request_error", "error", err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
slog.Debug("chatgpt_account_check_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200))
|
||||
return nil
|
||||
}
|
||||
|
||||
info := &ChatGPTAccountInfo{}
|
||||
|
||||
accounts, ok := result["accounts"].(map[string]any)
|
||||
if !ok {
|
||||
slog.Debug("chatgpt_account_check_no_accounts", "body", truncate(resp.String(), 300))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先匹配 orgID 对应的账号(access_token JWT 中的 poid)
|
||||
if orgID != "" {
|
||||
if matched := extractPlanFromAccount(accounts, orgID); matched != "" {
|
||||
info.PlanType = matched
|
||||
}
|
||||
}
|
||||
|
||||
// 未匹配到时,遍历所有账号:优先 is_default,次选非 free
|
||||
if info.PlanType == "" {
|
||||
var defaultPlan, paidPlan, anyPlan string
|
||||
for _, acctRaw := range accounts {
|
||||
acct, ok := acctRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
planType := extractPlanType(acct)
|
||||
if planType == "" {
|
||||
continue
|
||||
}
|
||||
if anyPlan == "" {
|
||||
anyPlan = planType
|
||||
}
|
||||
if account, ok := acct["account"].(map[string]any); ok {
|
||||
if isDefault, _ := account["is_default"].(bool); isDefault {
|
||||
defaultPlan = planType
|
||||
}
|
||||
}
|
||||
if !strings.EqualFold(planType, "free") && paidPlan == "" {
|
||||
paidPlan = planType
|
||||
}
|
||||
}
|
||||
// 优先级:default > 非 free > 任意
|
||||
switch {
|
||||
case defaultPlan != "":
|
||||
info.PlanType = defaultPlan
|
||||
case paidPlan != "":
|
||||
info.PlanType = paidPlan
|
||||
default:
|
||||
info.PlanType = anyPlan
|
||||
}
|
||||
}
|
||||
|
||||
if info.PlanType == "" {
|
||||
slog.Debug("chatgpt_account_check_no_plan_type", "body", truncate(resp.String(), 300))
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID)
|
||||
return info
|
||||
}
|
||||
|
||||
// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type
|
||||
func extractPlanFromAccount(accounts map[string]any, accountKey string) string {
|
||||
acctRaw, ok := accounts[accountKey]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
acct, ok := acctRaw.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return extractPlanType(acct)
|
||||
}
|
||||
|
||||
// extractPlanType 从单个 account 对象中提取 plan_type
|
||||
func extractPlanType(acct map[string]any) string {
|
||||
if account, ok := acct["account"].(map[string]any); ok {
|
||||
if planType, ok := account["plan_type"].(string); ok && planType != "" {
|
||||
return planType
|
||||
}
|
||||
}
|
||||
if entitlement, ok := acct["entitlement"].(map[string]any); ok {
|
||||
if subPlan, ok := entitlement["subscription_plan"].(string); ok && subPlan != "" {
|
||||
return subPlan
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
|
||||
@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
|
||||
require.Zero(t, boundAccountID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(24)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
staleAccount := &Account{
|
||||
ID: 13,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
dbAccount := Account{
|
||||
ID: 13,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
RateLimitResetAt: &rateLimitedUntil,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
snapshotCache := &openAISnapshotCacheStub{
|
||||
accountsByID: map[int64]*Account{dbAccount.ID: staleAccount},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
|
||||
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
|
||||
require.NoError(t, getErr)
|
||||
require.Zero(t, boundAccountID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
|
||||
@ -3840,6 +3840,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
return nil, nil
|
||||
}
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
||||
if account == nil {
|
||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
|
||||
@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
_ = platform
|
||||
_ = accountType
|
||||
_ = status
|
||||
_ = search
|
||||
_ = groupID
|
||||
_ = privacyMode
|
||||
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
|
||||
}
|
||||
|
||||
@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Len(t, accounts, 1)
|
||||
|
||||
@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
|
||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: opsAccountsPageSize,
|
||||
}, platformFilter, "", "", "", 0)
|
||||
}, platformFilter, "", "", "", 0, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -62,6 +62,12 @@ type OpsErrorLog struct {
|
||||
ClientIP *string `json:"client_ip"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Stream bool `json:"stream"`
|
||||
|
||||
InboundEndpoint string `json:"inbound_endpoint"`
|
||||
UpstreamEndpoint string `json:"upstream_endpoint"`
|
||||
RequestedModel string `json:"requested_model"`
|
||||
UpstreamModel string `json:"upstream_model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
}
|
||||
|
||||
type OpsErrorLogDetail struct {
|
||||
|
||||
@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct {
|
||||
Model string
|
||||
RequestPath string
|
||||
Stream bool
|
||||
// InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint string
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint string
|
||||
// RequestedModel is the client-requested model name before mapping.
|
||||
RequestedModel string
|
||||
// UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping.
|
||||
UpstreamModel string
|
||||
// RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2.
|
||||
// Matches service.RequestType enum semantics from usage_log.go.
|
||||
RequestType *int16
|
||||
UserAgent string
|
||||
|
||||
ErrorPhase string
|
||||
|
||||
@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct {
|
||||
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
||||
|
||||
// UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped).
|
||||
// Helps debug 404/routing errors by showing which endpoint was targeted.
|
||||
UpstreamURL string `json:"upstream_url,omitempty"`
|
||||
|
||||
// Best-effort upstream request capture (sanitized+trimmed).
|
||||
// Required for retrying a specific upstream attempt.
|
||||
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
|
||||
@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
|
||||
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
|
||||
ev.Kind = strings.TrimSpace(ev.Kind)
|
||||
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
|
||||
ev.Message = strings.TrimSpace(ev.Message)
|
||||
ev.Detail = strings.TrimSpace(ev.Detail)
|
||||
if ev.Message != "" {
|
||||
@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment
|
||||
// to avoid leaking sensitive query parameters (e.g. OAuth tokens).
|
||||
func safeUpstreamURL(rawURL string) string {
|
||||
rawURL = strings.TrimSpace(rawURL)
|
||||
if rawURL == "" {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.IndexByte(rawURL, '?'); idx >= 0 {
|
||||
rawURL = rawURL[:idx]
|
||||
}
|
||||
if idx := strings.IndexByte(rawURL, '#'); idx >= 0 {
|
||||
rawURL = rawURL[:idx]
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
|
||||
@ -8,6 +8,27 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeUpstreamURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"},
|
||||
{"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"},
|
||||
{"strips both", "https://host/path?token=secret#x", "https://host/path"},
|
||||
{"no query or fragment", "https://host/path", "https://host/path"},
|
||||
{"empty string", "", ""},
|
||||
{"whitespace only", " ", ""},
|
||||
{"query before fragment", "https://h/p?a=1#f", "https://h/p"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, safeUpstreamURL(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
|
||||
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
||||
} else {
|
||||
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
||||
|
||||
@ -15,9 +15,11 @@ import (
|
||||
|
||||
type rateLimitAccountRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
tempCalls int
|
||||
lastErrorMsg string
|
||||
setErrorCalls int
|
||||
tempCalls int
|
||||
updateCredentialsCalls int
|
||||
lastCredentials map[string]any
|
||||
lastErrorMsg string
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||
r.updateCredentialsCalls++
|
||||
r.lastCredentials = cloneCredentials(credentials)
|
||||
return nil
|
||||
}
|
||||
|
||||
type tokenCacheInvalidatorRecorder struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
}
|
||||
|
||||
@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.NotEmpty(t, repo.lastCredentials["expires_at"])
|
||||
}
|
||||
|
||||
@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic(
|
||||
func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) {
|
||||
|
||||
@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
|
||||
}
|
||||
|
||||
if c.accountRepo != nil {
|
||||
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
|
||||
if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
newCredentials, err = refresher.Refresh(ctx, account)
|
||||
if newCredentials != nil {
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
account.Credentials = newCredentials
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
@ -14,19 +14,40 @@ import (
|
||||
|
||||
type tokenRefreshAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updateCalls int
|
||||
setErrorCalls int
|
||||
clearTempCalls int
|
||||
lastAccount *Account
|
||||
updateErr error
|
||||
updateCalls int
|
||||
fullUpdateCalls int
|
||||
updateCredentialsCalls int
|
||||
setErrorCalls int
|
||||
clearTempCalls int
|
||||
lastAccount *Account
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.fullUpdateCalls++
|
||||
r.lastAccount = account
|
||||
return r.updateErr
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||
r.updateCalls++
|
||||
r.updateCredentialsCalls++
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
cloned := cloneCredentials(credentials)
|
||||
if r.accountsByID != nil {
|
||||
if acc, ok := r.accountsByID[id]; ok && acc != nil {
|
||||
acc.Credentials = cloned
|
||||
r.lastAccount = acc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
r.lastAccount = &Account{ID: id, Credentials: cloned}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 0, repo.fullUpdateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
require.Equal(t, "new-token", account.GetCredential("access_token"))
|
||||
}
|
||||
@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
resetAt := time.Now().Add(30 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
RateLimitResetAt: &resetAt,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
},
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "new-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 0, repo.fullUpdateCalls)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
||||
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
|
||||
@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
-- Ops error logs: add endpoint, model mapping, and request_type fields
|
||||
-- to match usage_logs observability coverage.
|
||||
--
|
||||
-- All columns are nullable with no default to preserve backward compatibility
|
||||
-- with existing rows.
|
||||
|
||||
SET LOCAL lock_timeout = '5s';
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint)
|
||||
ALTER TABLE ops_error_logs
|
||||
ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256),
|
||||
ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256);
|
||||
|
||||
-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model)
|
||||
ALTER TABLE ops_error_logs
|
||||
ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100),
|
||||
ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);
|
||||
|
||||
-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2)
|
||||
ALTER TABLE ops_error_logs
|
||||
ADD COLUMN IF NOT EXISTS request_type SMALLINT;
|
||||
|
||||
COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.';
|
||||
COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.';
|
||||
COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).';
|
||||
COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.';
|
||||
COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. Matches usage_logs.request_type semantics.';
|
||||
@ -36,6 +36,7 @@ export async function list(
|
||||
status?: string
|
||||
group?: string
|
||||
search?: string
|
||||
privacy_mode?: string
|
||||
lite?: string
|
||||
},
|
||||
options?: {
|
||||
@ -68,6 +69,7 @@ export async function listWithEtag(
|
||||
status?: string
|
||||
group?: string
|
||||
search?: string
|
||||
privacy_mode?: string
|
||||
lite?: string
|
||||
},
|
||||
options?: {
|
||||
@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
|
||||
export async function refreshOpenAIToken(
|
||||
refreshToken: string,
|
||||
proxyId?: number | null,
|
||||
endpoint: string = '/admin/openai/refresh-token'
|
||||
endpoint: string = '/admin/openai/refresh-token',
|
||||
clientId?: string
|
||||
): Promise<Record<string, unknown>> {
|
||||
const payload: { refresh_token: string; proxy_id?: number } = {
|
||||
const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = {
|
||||
refresh_token: refreshToken
|
||||
}
|
||||
if (proxyId) {
|
||||
payload.proxy_id = proxyId
|
||||
}
|
||||
if (clientId) {
|
||||
payload.client_id = clientId
|
||||
}
|
||||
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
|
||||
return data
|
||||
}
|
||||
|
||||
@ -969,6 +969,13 @@ export interface OpsErrorLog {
|
||||
client_ip?: string | null
|
||||
request_path?: string
|
||||
stream?: boolean
|
||||
|
||||
// Error observability context (endpoint + model mapping)
|
||||
inbound_endpoint?: string
|
||||
upstream_endpoint?: string
|
||||
requested_model?: string
|
||||
upstream_model?: string
|
||||
request_type?: number | null
|
||||
}
|
||||
|
||||
export interface OpsErrorDetail extends OpsErrorLog {
|
||||
|
||||
@ -599,6 +599,43 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI OAuth WS mode -->
|
||||
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<label
|
||||
id="bulk-edit-openai-ws-mode-label"
|
||||
class="input-label mb-0"
|
||||
for="bulk-edit-openai-ws-mode-enabled"
|
||||
>
|
||||
{{ t('admin.accounts.openai.wsMode') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="enableOpenAIWSMode"
|
||||
id="bulk-edit-openai-ws-mode-enabled"
|
||||
type="checkbox"
|
||||
aria-controls="bulk-edit-openai-ws-mode"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
id="bulk-edit-openai-ws-mode"
|
||||
:class="!enableOpenAIWSMode && 'pointer-events-none opacity-50'"
|
||||
>
|
||||
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.openai.wsModeDesc') }}
|
||||
</p>
|
||||
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t(openAIWSModeConcurrencyHintKey) }}
|
||||
</p>
|
||||
<Select
|
||||
v-model="openaiOAuthResponsesWebSocketV2Mode"
|
||||
data-testid="bulk-edit-openai-ws-mode-select"
|
||||
:options="openAIWSModeOptions"
|
||||
aria-labelledby="bulk-edit-openai-ws-mode-label"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- RPM Limit (仅全部为 Anthropic OAuth/SetupToken 时显示) -->
|
||||
<div v-if="allAnthropicOAuthOrSetupToken" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@ -821,6 +858,13 @@ import {
|
||||
buildModelMappingObject as buildModelMappingPayload,
|
||||
getPresetMappingsByPlatform
|
||||
} from '@/composables/useModelWhitelist'
|
||||
import {
|
||||
OPENAI_WS_MODE_OFF,
|
||||
OPENAI_WS_MODE_PASSTHROUGH,
|
||||
isOpenAIWSModeEnabled,
|
||||
resolveOpenAIWSModeConcurrencyHintKey
|
||||
} from '@/utils/openaiWsMode'
|
||||
import type { OpenAIWSMode } from '@/utils/openaiWsMode'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
@ -843,6 +887,15 @@ const appStore = useAppStore()
|
||||
// Platform awareness
|
||||
const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
|
||||
|
||||
const allOpenAIOAuth = computed(() => {
|
||||
return (
|
||||
props.selectedPlatforms.length === 1 &&
|
||||
props.selectedPlatforms[0] === 'openai' &&
|
||||
props.selectedTypes.length > 0 &&
|
||||
props.selectedTypes.every(t => t === 'oauth')
|
||||
)
|
||||
})
|
||||
|
||||
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
|
||||
const allAnthropicOAuthOrSetupToken = computed(() => {
|
||||
return (
|
||||
@ -886,6 +939,7 @@ const enablePriority = ref(false)
|
||||
const enableRateMultiplier = ref(false)
|
||||
const enableStatus = ref(false)
|
||||
const enableGroups = ref(false)
|
||||
const enableOpenAIWSMode = ref(false)
|
||||
const enableRpmLimit = ref(false)
|
||||
|
||||
// State - field values
|
||||
@ -907,6 +961,7 @@ const priority = ref(1)
|
||||
const rateMultiplier = ref(1)
|
||||
const status = ref<'active' | 'inactive'>('active')
|
||||
const groupIds = ref<number[]>([])
|
||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||
const rpmLimitEnabled = ref(false)
|
||||
const bulkBaseRpm = ref<number | null>(null)
|
||||
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
|
||||
@ -933,6 +988,13 @@ const statusOptions = computed(() => [
|
||||
{ value: 'active', label: t('common.active') },
|
||||
{ value: 'inactive', label: t('common.inactive') }
|
||||
])
|
||||
const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
|
||||
])
|
||||
const openAIWSModeConcurrencyHintKey = computed(() =>
|
||||
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
|
||||
)
|
||||
|
||||
// Model mapping helpers
|
||||
const addModelMapping = () => {
|
||||
@ -1015,6 +1077,12 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
const updates: Record<string, unknown> = {}
|
||||
const credentials: Record<string, unknown> = {}
|
||||
let credentialsChanged = false
|
||||
const ensureExtra = (): Record<string, unknown> => {
|
||||
if (!updates.extra) {
|
||||
updates.extra = {}
|
||||
}
|
||||
return updates.extra as Record<string, unknown>
|
||||
}
|
||||
|
||||
if (enableProxy.value) {
|
||||
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
|
||||
@ -1089,9 +1157,17 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
updates.credentials = credentials
|
||||
}
|
||||
|
||||
if (enableOpenAIWSMode.value) {
|
||||
const extra = ensureExtra()
|
||||
extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||
extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
|
||||
openaiOAuthResponsesWebSocketV2Mode.value
|
||||
)
|
||||
}
|
||||
|
||||
// RPM limit settings (写入 extra 字段)
|
||||
if (enableRpmLimit.value) {
|
||||
const extra: Record<string, unknown> = {}
|
||||
const extra = ensureExtra()
|
||||
if (rpmLimitEnabled.value && bulkBaseRpm.value != null && bulkBaseRpm.value > 0) {
|
||||
extra.base_rpm = bulkBaseRpm.value
|
||||
extra.rpm_strategy = bulkRpmStrategy.value
|
||||
@ -1111,8 +1187,7 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
|
||||
// UMQ mode(独立于 RPM 保存)
|
||||
if (userMsgQueueMode.value !== null) {
|
||||
if (!updates.extra) updates.extra = {}
|
||||
const umqExtra = updates.extra as Record<string, unknown>
|
||||
const umqExtra = ensureExtra()
|
||||
umqExtra.user_msg_queue_mode = userMsgQueueMode.value // '' = 清除账号级覆盖
|
||||
umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge)
|
||||
}
|
||||
@ -1178,6 +1253,7 @@ const handleSubmit = async () => {
|
||||
enableRateMultiplier.value ||
|
||||
enableStatus.value ||
|
||||
enableGroups.value ||
|
||||
enableOpenAIWSMode.value ||
|
||||
enableRpmLimit.value ||
|
||||
userMsgQueueMode.value !== null
|
||||
|
||||
@ -1269,6 +1345,7 @@ watch(
|
||||
enableRateMultiplier.value = false
|
||||
enableStatus.value = false
|
||||
enableGroups.value = false
|
||||
enableOpenAIWSMode.value = false
|
||||
enableRpmLimit.value = false
|
||||
|
||||
// Reset all values
|
||||
@ -1286,6 +1363,7 @@ watch(
|
||||
rateMultiplier.value = 1
|
||||
status.value = 'active'
|
||||
groupIds.value = []
|
||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||
rpmLimitEnabled.value = false
|
||||
bulkBaseRpm.value = null
|
||||
bulkRpmStrategy.value = 'tiered'
|
||||
|
||||
@ -2504,6 +2504,7 @@
|
||||
:allow-multiple="form.platform === 'anthropic'"
|
||||
:show-cookie-option="form.platform === 'anthropic'"
|
||||
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
|
||||
:show-mobile-refresh-token-option="form.platform === 'openai'"
|
||||
:show-session-token-option="form.platform === 'sora'"
|
||||
:show-access-token-option="form.platform === 'sora'"
|
||||
:platform="form.platform"
|
||||
@ -2511,6 +2512,7 @@
|
||||
@generate-url="handleGenerateUrl"
|
||||
@cookie-auth="handleCookieAuth"
|
||||
@validate-refresh-token="handleValidateRefreshToken"
|
||||
@validate-mobile-refresh-token="handleOpenAIValidateMobileRT"
|
||||
@validate-session-token="handleValidateSessionToken"
|
||||
@import-access-token="handleImportAccessToken"
|
||||
/>
|
||||
@ -4360,11 +4362,14 @@ const handleOpenAIExchange = async (authCode: string) => {
|
||||
}
|
||||
|
||||
// OpenAI 手动 RT 批量验证和创建
|
||||
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
// OpenAI Mobile RT 使用的 client_id(与后端 openai.SoraClientID 一致)
|
||||
const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK'
|
||||
|
||||
// OpenAI/Sora RT 批量验证和创建(共享逻辑)
|
||||
const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => {
|
||||
const oauthClient = activeOpenAIOAuth.value
|
||||
if (!refreshTokenInput.trim()) return
|
||||
|
||||
// Parse multiple refresh tokens (one per line)
|
||||
const refreshTokens = refreshTokenInput
|
||||
.split('\n')
|
||||
.map((rt) => rt.trim())
|
||||
@ -4389,7 +4394,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
try {
|
||||
const tokenInfo = await oauthClient.validateRefreshToken(
|
||||
refreshTokens[i],
|
||||
form.proxy_id
|
||||
form.proxy_id,
|
||||
clientId
|
||||
)
|
||||
if (!tokenInfo) {
|
||||
failedCount++
|
||||
@ -4399,6 +4405,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
}
|
||||
|
||||
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||
if (clientId) {
|
||||
credentials.client_id = clientId
|
||||
}
|
||||
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||
const extra = buildOpenAIExtra(oauthExtra)
|
||||
|
||||
@ -4410,8 +4419,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Generate account name with index for batch
|
||||
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||
// Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
|
||||
const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
|
||||
const accountName = refreshTokens.length > 1 ? `${baseName} #${i + 1}` : baseName
|
||||
|
||||
let openaiAccountId: string | number | undefined
|
||||
|
||||
@ -4494,6 +4504,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
// 手动输入 RT(Codex CLI client_id,默认)
|
||||
const handleOpenAIValidateRT = (rt: string) => handleOpenAIBatchRT(rt)
|
||||
|
||||
// 手动输入 Mobile RT(SoraClientID)
|
||||
const handleOpenAIValidateMobileRT = (rt: string) => handleOpenAIBatchRT(rt, OPENAI_MOBILE_RT_CLIENT_ID)
|
||||
|
||||
// Sora 手动 ST 批量验证和创建
|
||||
const handleSoraValidateST = async (sessionTokenInput: string) => {
|
||||
const oauthClient = activeOpenAIOAuth.value
|
||||
|
||||
@ -48,6 +48,17 @@
|
||||
t(getOAuthKey('refreshTokenAuth'))
|
||||
}}</span>
|
||||
</label>
|
||||
<label v-if="showMobileRefreshTokenOption" class="flex cursor-pointer items-center gap-2">
|
||||
<input
|
||||
v-model="inputMethod"
|
||||
type="radio"
|
||||
value="mobile_refresh_token"
|
||||
class="text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
<span class="text-sm text-blue-900 dark:text-blue-200">{{
|
||||
t('admin.accounts.oauth.openai.mobileRefreshTokenAuth', '手动输入 Mobile RT')
|
||||
}}</span>
|
||||
</label>
|
||||
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
|
||||
<input
|
||||
v-model="inputMethod"
|
||||
@ -73,8 +84,8 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Refresh Token Input (OpenAI / Antigravity) -->
|
||||
<div v-if="inputMethod === 'refresh_token'" class="space-y-4">
|
||||
<!-- Refresh Token Input (OpenAI / Antigravity / Mobile RT) -->
|
||||
<div v-if="inputMethod === 'refresh_token' || inputMethod === 'mobile_refresh_token'" class="space-y-4">
|
||||
<div
|
||||
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
|
||||
>
|
||||
@ -759,6 +770,7 @@ interface Props {
|
||||
methodLabel?: string
|
||||
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
||||
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
|
||||
showMobileRefreshTokenOption?: boolean // Whether to show mobile refresh token option (OpenAI only)
|
||||
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
|
||||
showAccessTokenOption?: boolean // Whether to show access token input option (Sora only)
|
||||
platform?: AccountPlatform // Platform type for different UI/text
|
||||
@ -776,6 +788,7 @@ const props = withDefaults(defineProps<Props>(), {
|
||||
methodLabel: 'Authorization Method',
|
||||
showCookieOption: true,
|
||||
showRefreshTokenOption: false,
|
||||
showMobileRefreshTokenOption: false,
|
||||
showSessionTokenOption: false,
|
||||
showAccessTokenOption: false,
|
||||
platform: 'anthropic',
|
||||
@ -787,6 +800,7 @@ const emit = defineEmits<{
|
||||
'exchange-code': [code: string]
|
||||
'cookie-auth': [sessionKey: string]
|
||||
'validate-refresh-token': [refreshToken: string]
|
||||
'validate-mobile-refresh-token': [refreshToken: string]
|
||||
'validate-session-token': [sessionToken: string]
|
||||
'import-access-token': [accessToken: string]
|
||||
'update:inputMethod': [method: AuthInputMethod]
|
||||
@ -834,7 +848,7 @@ const oauthState = ref('')
|
||||
const projectId = ref('')
|
||||
|
||||
// Computed: show method selection when either cookie or refresh token option is enabled
|
||||
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
|
||||
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
|
||||
|
||||
// Clipboard
|
||||
const { copied, copyToClipboard } = useClipboard()
|
||||
@ -945,7 +959,11 @@ const handleCookieAuth = () => {
|
||||
|
||||
const handleValidateRefreshToken = () => {
|
||||
if (refreshTokenInput.value.trim()) {
|
||||
emit('validate-refresh-token', refreshTokenInput.value.trim())
|
||||
if (inputMethod.value === 'mobile_refresh_token') {
|
||||
emit('validate-mobile-refresh-token', refreshTokenInput.value.trim())
|
||||
} else {
|
||||
emit('validate-refresh-token', refreshTokenInput.value.trim())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -50,7 +50,21 @@ function mountModal(extraProps: Record<string, unknown> = {}) {
|
||||
stubs: {
|
||||
BaseDialog: { template: '<div><slot /><slot name="footer" /></div>' },
|
||||
ConfirmDialog: true,
|
||||
Select: true,
|
||||
Select: {
|
||||
props: ['modelValue', 'options'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<select
|
||||
v-bind="$attrs"
|
||||
:value="modelValue"
|
||||
@change="$emit('update:modelValue', $event.target.value)"
|
||||
>
|
||||
<option v-for="option in options" :key="option.value" :value="option.value">
|
||||
{{ option.label }}
|
||||
</option>
|
||||
</select>
|
||||
`
|
||||
},
|
||||
ProxySelector: true,
|
||||
GroupSelector: true,
|
||||
Icon: true
|
||||
@ -115,4 +129,33 @@ describe('BulkEditAccountModal', () => {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('OpenAI OAuth 批量编辑应提交 OAuth 专属 WS mode 字段', async () => {
|
||||
const wrapper = mountModal({
|
||||
selectedPlatforms: ['openai'],
|
||||
selectedTypes: ['oauth']
|
||||
})
|
||||
|
||||
await wrapper.get('#bulk-edit-openai-ws-mode-enabled').setValue(true)
|
||||
await wrapper.get('[data-testid="bulk-edit-openai-ws-mode-select"]').setValue('passthrough')
|
||||
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
|
||||
await flushPromises()
|
||||
|
||||
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
|
||||
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
|
||||
extra: {
|
||||
openai_oauth_responses_websockets_v2_mode: 'passthrough',
|
||||
openai_oauth_responses_websockets_v2_enabled: true
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('OpenAI API Key 批量编辑不显示 WS mode 入口', () => {
|
||||
const wrapper = mountModal({
|
||||
selectedPlatforms: ['openai'],
|
||||
selectedTypes: ['apikey']
|
||||
})
|
||||
|
||||
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
|
||||
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
|
||||
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
|
||||
<Select :model-value="filters.privacy_mode" class="w-40" :options="privacyOpts" @update:model-value="updatePrivacyMode" @change="$emit('change')" />
|
||||
<Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" />
|
||||
</div>
|
||||
</template>
|
||||
@ -22,10 +23,18 @@ const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); co
|
||||
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
|
||||
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
|
||||
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
||||
const updatePrivacyMode = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, privacy_mode: value }) }
|
||||
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
|
||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
|
||||
const privacyOpts = computed(() => [
|
||||
{ value: '', label: t('admin.accounts.allPrivacyModes') },
|
||||
{ value: '__unset__', label: t('admin.accounts.privacyUnset') },
|
||||
{ value: 'training_off', label: 'Privacy' },
|
||||
{ value: 'training_set_cf_blocked', label: 'CF' },
|
||||
{ value: 'training_set_failed', label: 'Fail' }
|
||||
])
|
||||
const gOpts = computed(() => [
|
||||
{ value: '', label: t('admin.accounts.allGroups') },
|
||||
{ value: 'ungrouped', label: t('admin.accounts.ungroupedGroup') },
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { mount } from '@vue/test-utils'
|
||||
|
||||
import AccountTableFilters from '../AccountTableFilters.vue'
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
describe('AccountTableFilters', () => {
|
||||
it('renders privacy mode options and emits privacy_mode updates', async () => {
|
||||
const wrapper = mount(AccountTableFilters, {
|
||||
props: {
|
||||
searchQuery: '',
|
||||
filters: {
|
||||
platform: '',
|
||||
type: '',
|
||||
status: '',
|
||||
group: '',
|
||||
privacy_mode: ''
|
||||
},
|
||||
groups: []
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
SearchInput: {
|
||||
template: '<div />'
|
||||
},
|
||||
Select: {
|
||||
props: ['modelValue', 'options'],
|
||||
emits: ['update:modelValue', 'change'],
|
||||
template: '<div class="select-stub" :data-options="JSON.stringify(options)" />'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const selects = wrapper.findAll('.select-stub')
|
||||
expect(selects).toHaveLength(5)
|
||||
|
||||
const privacyOptions = JSON.parse(selects[3].attributes('data-options'))
|
||||
expect(privacyOptions).toEqual([
|
||||
{ value: '', label: 'admin.accounts.allPrivacyModes' },
|
||||
{ value: '__unset__', label: 'admin.accounts.privacyUnset' },
|
||||
{ value: 'training_off', label: 'Privacy' },
|
||||
{ value: 'training_set_cf_blocked', label: 'CF' },
|
||||
{ value: 'training_set_failed', label: 'Fail' }
|
||||
])
|
||||
})
|
||||
})
|
||||
@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
|
||||
export type AddMethod = 'oauth' | 'setup-token'
|
||||
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' | 'access_token'
|
||||
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'mobile_refresh_token' | 'session_token' | 'access_token'
|
||||
|
||||
export interface OAuthState {
|
||||
authUrl: string
|
||||
|
||||
@ -13,6 +13,8 @@ export interface OpenAITokenInfo {
|
||||
scope?: string
|
||||
email?: string
|
||||
name?: string
|
||||
plan_type?: string
|
||||
privacy_mode?: string
|
||||
// OpenAI specific IDs (extracted from ID Token)
|
||||
chatgpt_account_id?: string
|
||||
chatgpt_user_id?: string
|
||||
@ -126,9 +128,11 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
|
||||
}
|
||||
|
||||
// Validate refresh token and get full token info
|
||||
// clientId: 指定 OAuth client_id(用于第三方渠道获取的 RT,如 app_LlGpXReQgckcGGUo2JrYvtJK)
|
||||
const validateRefreshToken = async (
|
||||
refreshToken: string,
|
||||
proxyId?: number | null
|
||||
proxyId?: number | null,
|
||||
clientId?: string
|
||||
): Promise<OpenAITokenInfo | null> => {
|
||||
if (!refreshToken.trim()) {
|
||||
error.value = 'Missing refresh token'
|
||||
@ -143,11 +147,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
|
||||
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(
|
||||
refreshToken.trim(),
|
||||
proxyId,
|
||||
`${endpointPrefix}/refresh-token`
|
||||
`${endpointPrefix}/refresh-token`,
|
||||
clientId
|
||||
)
|
||||
return tokenInfo as OpenAITokenInfo
|
||||
} catch (err: any) {
|
||||
error.value = err.response?.data?.detail || 'Failed to validate refresh token'
|
||||
error.value = err.response?.data?.detail || err.message || 'Failed to validate refresh token'
|
||||
appStore.showError(error.value)
|
||||
return null
|
||||
} finally {
|
||||
@ -182,22 +187,23 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
|
||||
}
|
||||
}
|
||||
|
||||
// Build credentials for OpenAI OAuth account
|
||||
// Build credentials for OpenAI OAuth account (aligned with backend BuildAccountCredentials)
|
||||
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
|
||||
const creds: Record<string, unknown> = {
|
||||
access_token: tokenInfo.access_token,
|
||||
refresh_token: tokenInfo.refresh_token,
|
||||
token_type: tokenInfo.token_type,
|
||||
expires_in: tokenInfo.expires_in,
|
||||
expires_at: tokenInfo.expires_at,
|
||||
scope: tokenInfo.scope
|
||||
expires_at: tokenInfo.expires_at
|
||||
}
|
||||
|
||||
if (tokenInfo.client_id) {
|
||||
creds.client_id = tokenInfo.client_id
|
||||
// 仅在返回了新的 refresh_token 时才写入,防止用空值覆盖已有令牌
|
||||
if (tokenInfo.refresh_token) {
|
||||
creds.refresh_token = tokenInfo.refresh_token
|
||||
}
|
||||
if (tokenInfo.id_token) {
|
||||
creds.id_token = tokenInfo.id_token
|
||||
}
|
||||
if (tokenInfo.email) {
|
||||
creds.email = tokenInfo.email
|
||||
}
|
||||
|
||||
// Include OpenAI specific IDs (required for forwarding)
|
||||
if (tokenInfo.chatgpt_account_id) {
|
||||
creds.chatgpt_account_id = tokenInfo.chatgpt_account_id
|
||||
}
|
||||
@ -207,6 +213,12 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
|
||||
if (tokenInfo.organization_id) {
|
||||
creds.organization_id = tokenInfo.organization_id
|
||||
}
|
||||
if (tokenInfo.plan_type) {
|
||||
creds.plan_type = tokenInfo.plan_type
|
||||
}
|
||||
if (tokenInfo.client_id) {
|
||||
creds.client_id = tokenInfo.client_id
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
@ -220,6 +232,9 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
|
||||
if (tokenInfo.name) {
|
||||
extra.name = tokenInfo.name
|
||||
}
|
||||
if (tokenInfo.privacy_mode) {
|
||||
extra.privacy_mode = tokenInfo.privacy_mode
|
||||
}
|
||||
return Object.keys(extra).length > 0 ? extra : undefined
|
||||
}
|
||||
|
||||
|
||||
@ -1979,6 +1979,8 @@ export default {
|
||||
expiresAt: 'Expires At',
|
||||
actions: 'Actions'
|
||||
},
|
||||
allPrivacyModes: 'All Privacy States',
|
||||
privacyUnset: 'Unset',
|
||||
privacyTrainingOff: 'Training data sharing disabled',
|
||||
privacyCfBlocked: 'Blocked by Cloudflare, training may still be on',
|
||||
privacyFailed: 'Failed to disable training',
|
||||
@ -3494,7 +3496,12 @@ export default {
|
||||
typeRequest: 'Request',
|
||||
typeAuth: 'Auth',
|
||||
typeRouting: 'Routing',
|
||||
typeInternal: 'Internal'
|
||||
typeInternal: 'Internal',
|
||||
endpoint: 'Endpoint',
|
||||
requestType: 'Type',
|
||||
requestTypeSync: 'Sync',
|
||||
requestTypeStream: 'Stream',
|
||||
requestTypeWs: 'WS'
|
||||
},
|
||||
// Error Details Modal
|
||||
errorDetails: {
|
||||
@ -3580,6 +3587,16 @@ export default {
|
||||
latency: 'Request Duration',
|
||||
businessLimited: 'Business Limited',
|
||||
requestPath: 'Request Path',
|
||||
inboundEndpoint: 'Inbound Endpoint',
|
||||
upstreamEndpoint: 'Upstream Endpoint',
|
||||
requestedModel: 'Requested Model',
|
||||
upstreamModel: 'Upstream Model',
|
||||
requestType: 'Request Type',
|
||||
requestTypeUnknown: 'Unknown',
|
||||
requestTypeSync: 'Sync',
|
||||
requestTypeStream: 'Stream',
|
||||
requestTypeWs: 'WebSocket',
|
||||
modelMapping: 'Model Mapping',
|
||||
timings: 'Timings',
|
||||
auth: 'Auth',
|
||||
routing: 'Routing',
|
||||
|
||||
@ -2017,6 +2017,8 @@ export default {
|
||||
expiresAt: '过期时间',
|
||||
actions: '操作'
|
||||
},
|
||||
allPrivacyModes: '全部Privacy状态',
|
||||
privacyUnset: '未设置',
|
||||
privacyTrainingOff: '已关闭训练数据共享',
|
||||
privacyCfBlocked: '被 Cloudflare 拦截,训练可能仍开启',
|
||||
privacyFailed: '关闭训练数据共享失败',
|
||||
@ -3659,7 +3661,12 @@ export default {
|
||||
typeRequest: '请求',
|
||||
typeAuth: '认证',
|
||||
typeRouting: '路由',
|
||||
typeInternal: '内部'
|
||||
typeInternal: '内部',
|
||||
endpoint: '端点',
|
||||
requestType: '类型',
|
||||
requestTypeSync: '同步',
|
||||
requestTypeStream: '流式',
|
||||
requestTypeWs: 'WS'
|
||||
},
|
||||
// Error Details Modal
|
||||
errorDetails: {
|
||||
@ -3745,6 +3752,16 @@ export default {
|
||||
latency: '请求时长',
|
||||
businessLimited: '业务限制',
|
||||
requestPath: '请求路径',
|
||||
inboundEndpoint: '入站端点',
|
||||
upstreamEndpoint: '上游端点',
|
||||
requestedModel: '请求模型',
|
||||
upstreamModel: '上游模型',
|
||||
requestType: '请求类型',
|
||||
requestTypeUnknown: '未知',
|
||||
requestTypeSync: '同步',
|
||||
requestTypeStream: '流式',
|
||||
requestTypeWs: 'WebSocket',
|
||||
modelMapping: '模型映射',
|
||||
timings: '时序信息',
|
||||
auth: '认证',
|
||||
routing: '路由',
|
||||
|
||||
@ -581,7 +581,7 @@ const {
|
||||
handlePageSizeChange: baseHandlePageSizeChange
|
||||
} = useTableLoader<Account, any>({
|
||||
fetchFn: adminAPI.accounts.list,
|
||||
initialParams: { platform: '', type: '', status: '', group: '', search: '' }
|
||||
initialParams: { platform: '', type: '', status: '', privacy_mode: '', group: '', search: '' }
|
||||
})
|
||||
|
||||
const {
|
||||
@ -758,6 +758,7 @@ const refreshAccountsIncrementally = async () => {
|
||||
platform?: string
|
||||
type?: string
|
||||
status?: string
|
||||
privacy_mode?: string
|
||||
group?: string
|
||||
search?: string
|
||||
|
||||
|
||||
@ -1655,7 +1655,7 @@
|
||||
<button
|
||||
type="button"
|
||||
@click="testSmtpConnection"
|
||||
:disabled="testingSmtp"
|
||||
:disabled="testingSmtp || loadFailed"
|
||||
class="btn btn-secondary btn-sm"
|
||||
>
|
||||
<svg v-if="testingSmtp" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
|
||||
@ -1725,6 +1725,11 @@
|
||||
v-model="form.smtp_password"
|
||||
type="password"
|
||||
class="input"
|
||||
autocomplete="new-password"
|
||||
autocapitalize="off"
|
||||
spellcheck="false"
|
||||
@keydown="smtpPasswordManuallyEdited = true"
|
||||
@paste="smtpPasswordManuallyEdited = true"
|
||||
:placeholder="
|
||||
form.smtp_password_configured
|
||||
? t('admin.settings.smtp.passwordConfiguredPlaceholder')
|
||||
@ -1807,7 +1812,7 @@
|
||||
<button
|
||||
type="button"
|
||||
@click="sendTestEmail"
|
||||
:disabled="sendingTestEmail || !testEmailAddress"
|
||||
:disabled="sendingTestEmail || !testEmailAddress || loadFailed"
|
||||
class="btn btn-secondary"
|
||||
>
|
||||
<svg
|
||||
@ -1853,7 +1858,7 @@
|
||||
|
||||
<!-- Save Button -->
|
||||
<div v-show="activeTab !== 'backup' && activeTab !== 'data'" class="flex justify-end">
|
||||
<button type="submit" :disabled="saving" class="btn btn-primary">
|
||||
<button type="submit" :disabled="saving || loadFailed" class="btn btn-primary">
|
||||
<svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
|
||||
<circle
|
||||
class="opacity-25"
|
||||
@ -1924,9 +1929,11 @@ const settingsTabs = [
|
||||
const { copyToClipboard } = useClipboard()
|
||||
|
||||
const loading = ref(true)
|
||||
const loadFailed = ref(false)
|
||||
const saving = ref(false)
|
||||
const testingSmtp = ref(false)
|
||||
const sendingTestEmail = ref(false)
|
||||
const smtpPasswordManuallyEdited = ref(false)
|
||||
const testEmailAddress = ref('')
|
||||
const registrationEmailSuffixWhitelistTags = ref<string[]>([])
|
||||
const registrationEmailSuffixWhitelistDraft = ref('')
|
||||
@ -2201,6 +2208,7 @@ function removeEndpoint(index: number) {
|
||||
|
||||
async function loadSettings() {
|
||||
loading.value = true
|
||||
loadFailed.value = false
|
||||
try {
|
||||
const settings = await adminAPI.settings.getSettings()
|
||||
Object.assign(form, settings)
|
||||
@ -2218,9 +2226,11 @@ async function loadSettings() {
|
||||
)
|
||||
registrationEmailSuffixWhitelistDraft.value = ''
|
||||
form.smtp_password = ''
|
||||
smtpPasswordManuallyEdited.value = false
|
||||
form.turnstile_secret_key = ''
|
||||
form.linuxdo_connect_client_secret = ''
|
||||
} catch (error: any) {
|
||||
loadFailed.value = true
|
||||
appStore.showError(
|
||||
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
|
||||
)
|
||||
@ -2372,6 +2382,7 @@ async function saveSettings() {
|
||||
)
|
||||
registrationEmailSuffixWhitelistDraft.value = ''
|
||||
form.smtp_password = ''
|
||||
smtpPasswordManuallyEdited.value = false
|
||||
form.turnstile_secret_key = ''
|
||||
form.linuxdo_connect_client_secret = ''
|
||||
// Refresh cached settings so sidebar/header update immediately
|
||||
@ -2390,11 +2401,12 @@ async function saveSettings() {
|
||||
async function testSmtpConnection() {
|
||||
testingSmtp.value = true
|
||||
try {
|
||||
const smtpPasswordForTest = smtpPasswordManuallyEdited.value ? form.smtp_password : ''
|
||||
const result = await adminAPI.settings.testSmtpConnection({
|
||||
smtp_host: form.smtp_host,
|
||||
smtp_port: form.smtp_port,
|
||||
smtp_username: form.smtp_username,
|
||||
smtp_password: form.smtp_password,
|
||||
smtp_password: smtpPasswordForTest,
|
||||
smtp_use_tls: form.smtp_use_tls
|
||||
})
|
||||
// API returns { message: "..." } on success, errors are thrown as exceptions
|
||||
@ -2416,12 +2428,13 @@ async function sendTestEmail() {
|
||||
|
||||
sendingTestEmail.value = true
|
||||
try {
|
||||
const smtpPasswordForSend = smtpPasswordManuallyEdited.value ? form.smtp_password : ''
|
||||
const result = await adminAPI.settings.sendTestEmail({
|
||||
email: testEmailAddress.value,
|
||||
smtp_host: form.smtp_host,
|
||||
smtp_port: form.smtp_port,
|
||||
smtp_username: form.smtp_username,
|
||||
smtp_password: form.smtp_password,
|
||||
smtp_password: smtpPasswordForSend,
|
||||
smtp_from_email: form.smtp_from_email,
|
||||
smtp_from_name: form.smtp_from_name,
|
||||
smtp_use_tls: form.smtp_use_tls
|
||||
|
||||
@ -59,7 +59,28 @@
|
||||
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
|
||||
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.model') }}</div>
|
||||
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ detail.model || '—' }}
|
||||
<template v-if="hasModelMapping(detail)">
|
||||
<span class="font-mono">{{ detail.requested_model }}</span>
|
||||
<span class="mx-1 text-gray-400">→</span>
|
||||
<span class="font-mono text-primary-600 dark:text-primary-400">{{ detail.upstream_model }}</span>
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ displayModel(detail) || '—' }}
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
|
||||
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.inboundEndpoint') }}</div>
|
||||
<div class="mt-1 break-all font-mono text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ detail.inbound_endpoint || '—' }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
|
||||
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.upstreamEndpoint') }}</div>
|
||||
<div class="mt-1 break-all font-mono text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ detail.upstream_endpoint || '—' }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -72,6 +93,13 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
|
||||
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.requestType') }}</div>
|
||||
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ formatRequestTypeLabel(detail.request_type) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="rounded-xl bg-gray-50 p-4 dark:bg-dark-900">
|
||||
<div class="text-xs font-bold uppercase tracking-wider text-gray-400">{{ t('admin.ops.errorDetail.message') }}</div>
|
||||
<div class="mt-1 truncate text-sm font-medium text-gray-900 dark:text-white" :title="detail.message">
|
||||
@ -213,6 +241,31 @@ function isUpstreamError(d: OpsErrorDetail | null): boolean {
|
||||
return phase === 'upstream' && owner === 'provider'
|
||||
}
|
||||
|
||||
function formatRequestTypeLabel(type: number | null | undefined): string {
|
||||
switch (type) {
|
||||
case 1: return t('admin.ops.errorDetail.requestTypeSync')
|
||||
case 2: return t('admin.ops.errorDetail.requestTypeStream')
|
||||
case 3: return t('admin.ops.errorDetail.requestTypeWs')
|
||||
default: return t('admin.ops.errorDetail.requestTypeUnknown')
|
||||
}
|
||||
}
|
||||
|
||||
function hasModelMapping(d: OpsErrorDetail | null): boolean {
|
||||
if (!d) return false
|
||||
const requested = String(d.requested_model || '').trim()
|
||||
const upstream = String(d.upstream_model || '').trim()
|
||||
return !!requested && !!upstream && requested !== upstream
|
||||
}
|
||||
|
||||
function displayModel(d: OpsErrorDetail | null): string {
|
||||
if (!d) return ''
|
||||
const upstream = String(d.upstream_model || '').trim()
|
||||
if (upstream) return upstream
|
||||
const requested = String(d.requested_model || '').trim()
|
||||
if (requested) return requested
|
||||
return String(d.model || '').trim()
|
||||
}
|
||||
|
||||
const correlatedUpstream = ref<OpsErrorDetail[]>([])
|
||||
const correlatedUpstreamLoading = ref(false)
|
||||
|
||||
|
||||
@ -17,6 +17,9 @@
|
||||
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
|
||||
{{ t('admin.ops.errorLog.type') }}
|
||||
</th>
|
||||
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
|
||||
{{ t('admin.ops.errorLog.endpoint') }}
|
||||
</th>
|
||||
<th class="border-b border-gray-200 px-4 py-2.5 text-left text-[11px] font-bold uppercase tracking-wider text-gray-500 dark:border-dark-700 dark:text-dark-400">
|
||||
{{ t('admin.ops.errorLog.platform') }}
|
||||
</th>
|
||||
@ -42,7 +45,7 @@
|
||||
</thead>
|
||||
<tbody class="divide-y divide-gray-100 dark:divide-dark-700">
|
||||
<tr v-if="rows.length === 0">
|
||||
<td colspan="9" class="py-12 text-center text-sm text-gray-400 dark:text-dark-500">
|
||||
<td colspan="10" class="py-12 text-center text-sm text-gray-400 dark:text-dark-500">
|
||||
{{ t('admin.ops.errorLog.noErrors') }}
|
||||
</td>
|
||||
</tr>
|
||||
@ -74,6 +77,18 @@
|
||||
</span>
|
||||
</td>
|
||||
|
||||
<!-- Endpoint -->
|
||||
<td class="px-4 py-2">
|
||||
<div class="max-w-[160px]">
|
||||
<el-tooltip v-if="log.inbound_endpoint" :content="formatEndpointTooltip(log)" placement="top" :show-after="500">
|
||||
<span class="truncate font-mono text-[11px] text-gray-700 dark:text-gray-300">
|
||||
{{ log.inbound_endpoint }}
|
||||
</span>
|
||||
</el-tooltip>
|
||||
<span v-else class="text-xs text-gray-400">-</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<!-- Platform -->
|
||||
<td class="whitespace-nowrap px-4 py-2">
|
||||
<span class="inline-flex items-center rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold uppercase text-gray-600 dark:bg-dark-700 dark:text-gray-300">
|
||||
@ -83,11 +98,22 @@
|
||||
|
||||
<!-- Model -->
|
||||
<td class="px-4 py-2">
|
||||
<div class="max-w-[120px] truncate" :title="log.model">
|
||||
<span v-if="log.model" class="font-mono text-[11px] text-gray-700 dark:text-gray-300">
|
||||
{{ log.model }}
|
||||
</span>
|
||||
<span v-else class="text-xs text-gray-400">-</span>
|
||||
<div class="max-w-[160px]">
|
||||
<template v-if="hasModelMapping(log)">
|
||||
<el-tooltip :content="modelMappingTooltip(log)" placement="top" :show-after="500">
|
||||
<span class="flex items-center gap-1 truncate font-mono text-[11px] text-gray-700 dark:text-gray-300">
|
||||
<span class="truncate">{{ log.requested_model }}</span>
|
||||
<span class="flex-shrink-0 text-gray-400">→</span>
|
||||
<span class="truncate text-primary-600 dark:text-primary-400">{{ log.upstream_model }}</span>
|
||||
</span>
|
||||
</el-tooltip>
|
||||
</template>
|
||||
<template v-else>
|
||||
<span v-if="displayModel(log)" class="truncate font-mono text-[11px] text-gray-700 dark:text-gray-300" :title="displayModel(log)">
|
||||
{{ displayModel(log) }}
|
||||
</span>
|
||||
<span v-else class="text-xs text-gray-400">-</span>
|
||||
</template>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
@ -138,6 +164,12 @@
|
||||
>
|
||||
{{ log.severity }}
|
||||
</span>
|
||||
<span
|
||||
v-if="log.request_type != null && log.request_type > 0"
|
||||
class="rounded bg-gray-100 px-1.5 py-0.5 text-[10px] font-bold text-gray-600 dark:bg-dark-700 dark:text-gray-300"
|
||||
>
|
||||
{{ formatRequestType(log.request_type) }}
|
||||
</span>
|
||||
</div>
|
||||
</td>
|
||||
|
||||
@ -193,6 +225,44 @@ function isUpstreamRow(log: OpsErrorLog): boolean {
|
||||
return phase === 'upstream' && owner === 'provider'
|
||||
}
|
||||
|
||||
function formatEndpointTooltip(log: OpsErrorLog): string {
|
||||
const parts: string[] = []
|
||||
if (log.inbound_endpoint) parts.push(`Inbound: ${log.inbound_endpoint}`)
|
||||
if (log.upstream_endpoint) parts.push(`Upstream: ${log.upstream_endpoint}`)
|
||||
return parts.join('\n') || ''
|
||||
}
|
||||
|
||||
function hasModelMapping(log: OpsErrorLog): boolean {
|
||||
const requested = String(log.requested_model || '').trim()
|
||||
const upstream = String(log.upstream_model || '').trim()
|
||||
return !!requested && !!upstream && requested !== upstream
|
||||
}
|
||||
|
||||
function modelMappingTooltip(log: OpsErrorLog): string {
|
||||
const requested = String(log.requested_model || '').trim()
|
||||
const upstream = String(log.upstream_model || '').trim()
|
||||
if (!requested && !upstream) return ''
|
||||
if (requested && upstream) return `${requested} → ${upstream}`
|
||||
return upstream || requested
|
||||
}
|
||||
|
||||
function displayModel(log: OpsErrorLog): string {
|
||||
const upstream = String(log.upstream_model || '').trim()
|
||||
if (upstream) return upstream
|
||||
const requested = String(log.requested_model || '').trim()
|
||||
if (requested) return requested
|
||||
return String(log.model || '').trim()
|
||||
}
|
||||
|
||||
function formatRequestType(type: number | null | undefined): string {
|
||||
switch (type) {
|
||||
case 1: return t('admin.ops.errorLog.requestTypeSync')
|
||||
case 2: return t('admin.ops.errorLog.requestTypeStream')
|
||||
case 3: return t('admin.ops.errorLog.requestTypeWs')
|
||||
default: return ''
|
||||
}
|
||||
}
|
||||
|
||||
function getTypeBadge(log: OpsErrorLog): { label: string; className: string } {
|
||||
const phase = String(log.phase || '').toLowerCase()
|
||||
const owner = String(log.error_owner || '').toLowerCase()
|
||||
@ -263,4 +333,4 @@ function formatSmartMessage(msg: string): string {
|
||||
return msg.length > 200 ? msg.substring(0, 200) + '...' : msg
|
||||
|
||||
}
|
||||
</script>
|
||||
</script>
|
||||
|
||||
@ -344,7 +344,7 @@ onMounted(async () => {
|
||||
<div class="text-xs font-semibold text-gray-700 dark:text-gray-200">运行时日志配置(实时生效)</div>
|
||||
<span v-if="runtimeLoading" class="text-xs text-gray-500">加载中...</span>
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-6">
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-2 xl:grid-cols-6">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="runtimeConfig.level" class="input mt-1">
|
||||
@ -374,21 +374,27 @@ onMounted(async () => {
|
||||
保留天数
|
||||
<input v-model.number="runtimeConfig.retention_days" type="number" min="1" max="3650" class="input mt-1" />
|
||||
</label>
|
||||
<div class="flex items-end gap-2">
|
||||
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
|
||||
<input v-model="runtimeConfig.caller" type="checkbox" />
|
||||
caller
|
||||
</label>
|
||||
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
|
||||
<input v-model="runtimeConfig.enable_sampling" type="checkbox" />
|
||||
sampling
|
||||
</label>
|
||||
<button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig">
|
||||
{{ runtimeSaving ? '保存中...' : '保存并生效' }}
|
||||
</button>
|
||||
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig">
|
||||
回滚默认值
|
||||
</button>
|
||||
<div class="md:col-span-2 xl:col-span-6">
|
||||
<div class="grid gap-3 lg:grid-cols-[minmax(0,1fr)_auto] lg:items-end">
|
||||
<div class="flex flex-wrap items-center gap-x-4 gap-y-2">
|
||||
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
|
||||
<input v-model="runtimeConfig.caller" type="checkbox" />
|
||||
caller
|
||||
</label>
|
||||
<label class="inline-flex items-center gap-2 text-xs text-gray-600 dark:text-gray-300">
|
||||
<input v-model="runtimeConfig.enable_sampling" type="checkbox" />
|
||||
sampling
|
||||
</label>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2 lg:justify-end">
|
||||
<button type="button" class="btn btn-primary btn-sm" :disabled="runtimeSaving" @click="saveRuntimeConfig">
|
||||
{{ runtimeSaving ? '保存中...' : '保存并生效' }}
|
||||
</button>
|
||||
<button type="button" class="btn btn-secondary btn-sm" :disabled="runtimeSaving" @click="resetRuntimeConfig">
|
||||
回滚默认值
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<p v-if="health.last_error" class="mt-2 text-xs text-red-600 dark:text-red-400">最近写入错误:{{ health.last_error }}</p>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user