Add OpenAI embeddings gateway
This commit is contained in:
parent
89d96f4b25
commit
ccace69d4e
@ -17,6 +17,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
EndpointMessages = "/v1/messages"
|
EndpointMessages = "/v1/messages"
|
||||||
EndpointChatCompletions = "/v1/chat/completions"
|
EndpointChatCompletions = "/v1/chat/completions"
|
||||||
|
EndpointEmbeddings = "/v1/embeddings"
|
||||||
EndpointResponses = "/v1/responses"
|
EndpointResponses = "/v1/responses"
|
||||||
EndpointImagesGenerations = "/v1/images/generations"
|
EndpointImagesGenerations = "/v1/images/generations"
|
||||||
EndpointImagesEdits = "/v1/images/edits"
|
EndpointImagesEdits = "/v1/images/edits"
|
||||||
@ -42,6 +43,8 @@ const (
|
|||||||
func NormalizeInboundEndpoint(path string) string {
|
func NormalizeInboundEndpoint(path string) string {
|
||||||
path = strings.TrimSpace(path)
|
path = strings.TrimSpace(path)
|
||||||
switch {
|
switch {
|
||||||
|
case strings.Contains(path, EndpointEmbeddings):
|
||||||
|
return EndpointEmbeddings
|
||||||
case strings.Contains(path, EndpointChatCompletions):
|
case strings.Contains(path, EndpointChatCompletions):
|
||||||
return EndpointChatCompletions
|
return EndpointChatCompletions
|
||||||
case strings.Contains(path, EndpointMessages):
|
case strings.Contains(path, EndpointMessages):
|
||||||
@ -75,7 +78,7 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
|||||||
|
|
||||||
switch platform {
|
switch platform {
|
||||||
case service.PlatformOpenAI:
|
case service.PlatformOpenAI:
|
||||||
if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
|
if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
|
||||||
return inbound
|
return inbound
|
||||||
}
|
}
|
||||||
// OpenAI forwards everything to the Responses API.
|
// OpenAI forwards everything to the Responses API.
|
||||||
|
|||||||
@ -24,6 +24,7 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
|
|||||||
// Direct canonical paths.
|
// Direct canonical paths.
|
||||||
{"/v1/messages", EndpointMessages},
|
{"/v1/messages", EndpointMessages},
|
||||||
{"/v1/chat/completions", EndpointChatCompletions},
|
{"/v1/chat/completions", EndpointChatCompletions},
|
||||||
|
{"/v1/embeddings", EndpointEmbeddings},
|
||||||
{"/v1/responses", EndpointResponses},
|
{"/v1/responses", EndpointResponses},
|
||||||
{"/v1/images/generations", EndpointImagesGenerations},
|
{"/v1/images/generations", EndpointImagesGenerations},
|
||||||
{"/v1/images/edits", EndpointImagesEdits},
|
{"/v1/images/edits", EndpointImagesEdits},
|
||||||
@ -77,6 +78,7 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
|
|||||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai embeddings", EndpointEmbeddings, "/v1/embeddings", service.PlatformOpenAI, EndpointEmbeddings},
|
||||||
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
|
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
|
||||||
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
|
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
|
||||||
|
|
||||||
|
|||||||
253
backend/internal/handler/openai_embeddings.go
Normal file
253
backend/internal/handler/openai_embeddings.go
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Embeddings handles the OpenAI-compatible Embeddings API.
// POST /v1/embeddings
//
// Flow: read the API key and auth subject from the gin context, validate the
// JSON body and its "model" field, acquire a per-user slot, check billing
// eligibility, then loop: select an account, forward the body through
// ForwardEmbeddings and, on an UpstreamFailoverError, exclude that account and
// select another one. On success a usage-record task is submitted and the
// handler returns. Every early exit writes its own error response (or relies
// on the helper that failed having done so).
func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
	// Only ever passed by pointer to the slot-acquisition helpers below; this
	// function itself never sets it to true.
	streamStarted := false
	requestStart := time.Now()

	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}

	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	// Request-scoped logger tagged with the caller's identity.
	reqLog := requestLogger(
		c,
		"handler.openai_gateway.embeddings",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)
	// NOTE(review): presumably writes its own error response when a dependency
	// is missing, since nothing is written here — confirm.
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}

	// Read and validate the raw body: size limit, non-empty, syntactically valid JSON.
	body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
	if err != nil {
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	if !gjson.ValidBytes(body) {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// "model" must be present, a JSON string, and non-blank after trimming.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	// Note: the value is used as sent (not trimmed) from here on.
	reqModel := modelResult.String()
	reqLog = reqLog.With(zap.String("model", reqModel))
	setOpsRequestContext(c, reqModel, false)
	setOpsEndpointContext(c, "", int16(service.RequestTypeSync))

	// The second return value is intentionally discarded here.
	channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)

	// A missing subscription is not treated as an error at this point.
	subscription, _ := middleware2.GetSubscriptionFromContext(c)
	service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())

	// Per-user concurrency slot, held until the handler returns.
	userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog)
	if !acquired {
		return
	}
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
		reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err))
		status, code, message, retryAfter := billingErrorDetails(err)
		if retryAfter > 0 {
			c.Header("Retry-After", strconv.Itoa(retryAfter))
		}
		h.errorResponse(c, status, code, message)
		return
	}

	// Failover bookkeeping: accounts already tried or skipped, the most recent
	// failover error, and how many account switches have happened so far.
	failedAccountIDs := make(map[int64]struct{})
	var lastFailoverErr *service.UpstreamFailoverError
	switchCount := 0
	maxAccountSwitches := h.maxAccountSwitches
	if maxAccountSwitches <= 0 {
		maxAccountSwitches = 3
	}
	routingStart := time.Now()

	// Each iteration selects one account and attempts one forward. The loop
	// only continues after a skipped account or a failover error.
	for {
		selection, _, err := h.gatewayService.SelectAccountWithScheduler(
			c.Request.Context(),
			apiKey.GroupID,
			"",
			"",
			reqModel,
			failedAccountIDs,
			service.OpenAIUpstreamTransportHTTPSSE,
			false,
		)
		if err != nil {
			reqLog.Warn("openai_embeddings.account_select_failed",
				zap.Error(err),
				zap.Int("excluded_account_count", len(failedAccountIDs)),
			)
			// Nothing has been tried yet: report as unavailable (503).
			if len(failedAccountIDs) == 0 {
				markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
				h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
				return
			}
			// Accounts were tried or skipped: surface the last failover error
			// if there is one, otherwise a generic 502.
			if lastFailoverErr != nil {
				h.handleFailoverExhausted(c, lastFailoverErr, false)
			} else {
				h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed")
			}
			return
		}
		if selection == nil || selection.Account == nil {
			markOpsRoutingCapacityLimited(c)
			h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
			return
		}
		account := selection.Account
		// Only API-key accounts are used for embeddings; any other type is
		// released, excluded, and the selection is retried.
		if account.Type != service.AccountTypeAPIKey {
			if selection.ReleaseFunc != nil {
				selection.ReleaseFunc()
			}
			failedAccountIDs[account.ID] = struct{}{}
			continue
		}
		setOpsSelectedAccount(c, account.ID, account.Platform)

		accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
		if !accountAcquired {
			return
		}

		service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
		forwardStart := time.Now()

		// Apply the channel-level model mapping before forwarding, if any.
		forwardBody := body
		if channelMapping.Mapped {
			forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
		}
		// Snapshot of the writer size, used below to detect whether the
		// forward call already wrote something to the client.
		writerSizeBeforeForward := c.Writer.Size()
		// Wrapped in a closure so the account slot is released as soon as the
		// forward call returns rather than when the handler returns.
		result, err := func() (*service.OpenAIForwardResult, error) {
			defer func() {
				if accountReleaseFunc != nil {
					accountReleaseFunc()
				}
			}()
			return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "")
		}()

		// Response latency is the forward duration minus the recorded upstream
		// latency, when the latter is present and smaller.
		forwardDurationMs := time.Since(forwardStart).Milliseconds()
		upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
		responseLatencyMs := forwardDurationMs
		if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
			responseLatencyMs = forwardDurationMs - upstreamLatencyMs
		}
		service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)

		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				// Something was already written to the client, so switching
				// accounts is no longer possible.
				if c.Writer.Size() != writerSizeBeforeForward {
					h.handleFailoverExhausted(c, failoverErr, true)
					return
				}
				h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
				h.gatewayService.RecordOpenAIAccountSwitch()
				failedAccountIDs[account.ID] = struct{}{}
				lastFailoverErr = failoverErr
				// Checked before incrementing, so up to maxAccountSwitches
				// switches (maxAccountSwitches+1 forward attempts) are allowed.
				if switchCount >= maxAccountSwitches {
					h.handleFailoverExhausted(c, failoverErr, false)
					return
				}
				switchCount++
				reqLog.Warn("openai_embeddings.upstream_failover_switching",
					zap.Int64("account_id", account.ID),
					zap.Int("upstream_status", failoverErr.StatusCode),
					zap.Int("switch_count", switchCount),
					zap.Int("max_switches", maxAccountSwitches),
				)
				continue
			}
			// Non-failover error: report the failure and write a generic 502
			// only if the forward call has not already written a response.
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
			if c.Writer.Size() == writerSizeBeforeForward {
				h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
			}
			reqLog.Warn("openai_embeddings.forward_failed",
				zap.Int64("account_id", account.ID),
				zap.Error(err),
			)
			return
		}

		h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
		// Capture request-derived values now; the usage task below receives
		// its own context and does not read from c.
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)
		inboundEndpoint := GetInboundEndpoint(c)
		upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)

		h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
				Result:             result,
				APIKey:             apiKey,
				User:               apiKey.User,
				Account:            account,
				Subscription:       subscription,
				InboundEndpoint:    inboundEndpoint,
				UpstreamEndpoint:   upstreamEndpoint,
				UserAgent:          userAgent,
				IPAddress:          clientIP,
				APIKeyService:      h.apiKeyService,
				ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
			}); err != nil {
				// Uses the global logger rather than reqLog inside the task.
				logger.L().With(
					zap.String("component", "handler.openai_gateway.embeddings"),
					zap.Int64("user_id", subject.UserID),
					zap.Int64("api_key_id", apiKey.ID),
					zap.Any("group_id", apiKey.GroupID),
					zap.String("model", reqModel),
					zap.Int64("account_id", account.ID),
				).Error("openai_embeddings.record_usage_failed", zap.Error(err))
			}
		})
		reqLog.Debug("openai_embeddings.request_completed",
			zap.Int64("account_id", account.ID),
			zap.Int("switch_count", switchCount),
		)
		return
	}
}
|
||||||
@ -89,6 +89,19 @@ func RegisterGatewayRoutes(
|
|||||||
}
|
}
|
||||||
h.Gateway.ChatCompletions(c)
|
h.Gateway.ChatCompletions(c)
|
||||||
})
|
})
|
||||||
|
gateway.POST("/embeddings", func(c *gin.Context) {
|
||||||
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
|
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "not_found_error",
|
||||||
|
"message": "Embeddings API is not supported for this platform",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.OpenAIGateway.Embeddings(c)
|
||||||
|
})
|
||||||
gateway.POST("/images/generations", func(c *gin.Context) {
|
gateway.POST("/images/generations", func(c *gin.Context) {
|
||||||
if getGroupPlatform(c) != service.PlatformOpenAI {
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
|
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
|
||||||
@ -158,6 +171,19 @@ func RegisterGatewayRoutes(
|
|||||||
}
|
}
|
||||||
h.Gateway.ChatCompletions(c)
|
h.Gateway.ChatCompletions(c)
|
||||||
})
|
})
|
||||||
|
r.POST("/embeddings", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
|
||||||
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
|
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "not_found_error",
|
||||||
|
"message": "Embeddings API is not supported for this platform",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.OpenAIGateway.Embeddings(c)
|
||||||
|
})
|
||||||
r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
|
r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
|
||||||
if getGroupPlatform(c) != service.PlatformOpenAI {
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
|
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
|
||||||
|
|||||||
240
backend/internal/service/openai_embeddings.go
Normal file
240
backend/internal/service/openai_embeddings.go
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) ForwardEmbeddings(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
body []byte,
|
||||||
|
defaultMappedModel string,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
originalModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||||
|
if originalModel == "" {
|
||||||
|
writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
|
return nil, fmt.Errorf("missing model in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||||
|
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||||
|
upstreamBody := body
|
||||||
|
if upstreamModel != originalModel {
|
||||||
|
upstreamBody = ReplaceModelInBody(body, upstreamModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.L().Debug("openai embeddings: forwarding",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("original_model", originalModel),
|
||||||
|
zap.String("billing_model", billingModel),
|
||||||
|
zap.String("upstream_model", upstreamModel),
|
||||||
|
)
|
||||||
|
|
||||||
|
apiKey := account.GetOpenAIApiKey()
|
||||||
|
if apiKey == "" {
|
||||||
|
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||||
|
}
|
||||||
|
baseURL := account.GetOpenAIBaseURL()
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com"
|
||||||
|
}
|
||||||
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||||
|
}
|
||||||
|
targetURL := buildOpenAIEmbeddingsURL(validatedURL)
|
||||||
|
|
||||||
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
|
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||||
|
releaseUpstreamCtx()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
|
}
|
||||||
|
upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI))
|
||||||
|
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||||
|
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
upstreamReq.Header.Set("Accept", "application/json")
|
||||||
|
for key, values := range c.Request.Header {
|
||||||
|
lowerKey := strings.ToLower(key)
|
||||||
|
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||||
|
for _, v := range values {
|
||||||
|
upstreamReq.Header.Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
|
||||||
|
upstreamReq.Header.Set("user-agent", customUA)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
|
||||||
|
return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
|
||||||
|
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")),
|
||||||
|
Usage: extractOpenAIEmbeddingsUsage(respBody),
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: billingModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
Stream: false,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) {
|
||||||
|
if c == nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.Writer.Written() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.Header != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, filter)
|
||||||
|
}
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||||
|
c.Writer.Header().Set("Content-Type", ct)
|
||||||
|
} else {
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, _ = c.Writer.Write(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage {
|
||||||
|
usage := gjson.GetBytes(body, "usage")
|
||||||
|
if !usage.Exists() || !usage.IsObject() {
|
||||||
|
return OpenAIUsage{}
|
||||||
|
}
|
||||||
|
inputTokens := firstPositiveGJSONInt(
|
||||||
|
usage.Get("prompt_tokens"),
|
||||||
|
usage.Get("input_tokens"),
|
||||||
|
usage.Get("total_tokens"),
|
||||||
|
)
|
||||||
|
outputTokens := firstPositiveGJSONInt(
|
||||||
|
usage.Get("completion_tokens"),
|
||||||
|
usage.Get("output_tokens"),
|
||||||
|
)
|
||||||
|
cacheReadTokens := firstPositiveGJSONInt(
|
||||||
|
usage.Get("prompt_tokens_details.cached_tokens"),
|
||||||
|
usage.Get("input_tokens_details.cached_tokens"),
|
||||||
|
usage.Get("cache_read_tokens"),
|
||||||
|
usage.Get("cache_read_input_tokens"),
|
||||||
|
)
|
||||||
|
cacheCreationTokens := firstPositiveGJSONInt(
|
||||||
|
usage.Get("cache_creation_tokens"),
|
||||||
|
usage.Get("cache_creation_input_tokens"),
|
||||||
|
usage.Get("input_tokens_details.cache_creation_tokens"),
|
||||||
|
)
|
||||||
|
return OpenAIUsage{
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
CacheReadInputTokens: cacheReadTokens,
|
||||||
|
CacheCreationInputTokens: cacheCreationTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstPositiveGJSONInt(values ...gjson.Result) int {
|
||||||
|
for _, value := range values {
|
||||||
|
if !value.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n := int(value.Int())
|
||||||
|
if n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIEmbeddingsURL(base string) string {
|
||||||
|
return buildOpenAIEndpointURL(base, "/v1/embeddings")
|
||||||
|
}
|
||||||
106
backend/internal/service/openai_embeddings_test.go
Normal file
106
backend/internal/service/openai_embeddings_test.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOpenAIEmbeddingsURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
base string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/embeddings"},
|
||||||
|
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/embeddings"},
|
||||||
|
{"already embeddings", "https://api.openai.com/v1/embeddings", "https://api.openai.com/v1/embeddings"},
|
||||||
|
{"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/embeddings"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.Equal(t, tt.want, buildOpenAIEmbeddingsURL(tt.base))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardEmbeddings_APIKeyPassthroughRecordsUsageAndBatchInput(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
reqBody := []byte(`{
|
||||||
|
"model":"nowledge-embedding",
|
||||||
|
"input":["hello","world"],
|
||||||
|
"encoding_format":"float",
|
||||||
|
"dimensions":256
|
||||||
|
}`)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/embeddings", bytes.NewReader(reqBody))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"emb-rid"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{
|
||||||
|
"object":"list",
|
||||||
|
"data":[
|
||||||
|
{"object":"embedding","index":0,"embedding":[0.1,0.2]},
|
||||||
|
{"object":"embedding","index":1,"embedding":[0.3,0.4]}
|
||||||
|
],
|
||||||
|
"model":"jina-embeddings-v5-text-small",
|
||||||
|
"usage":{"prompt_tokens":13,"total_tokens":13}
|
||||||
|
}`)),
|
||||||
|
}}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"base_url": "https://api.jina.ai",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"nowledge-embedding": "jina-embeddings-v5-text-small",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardEmbeddings(context.Background(), c, account, reqBody, "")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "emb-rid", result.RequestID)
|
||||||
|
require.Equal(t, "nowledge-embedding", result.Model)
|
||||||
|
require.Equal(t, "jina-embeddings-v5-text-small", result.BillingModel)
|
||||||
|
require.Equal(t, "jina-embeddings-v5-text-small", result.UpstreamModel)
|
||||||
|
require.Equal(t, 13, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 0, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, "https://api.jina.ai/v1/embeddings", upstream.lastReq.URL.String())
|
||||||
|
require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "jina-embeddings-v5-text-small", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||||
|
require.Equal(t, int64(2), gjson.GetBytes(upstream.lastBody, "input.#").Int())
|
||||||
|
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "input.0").String())
|
||||||
|
require.Equal(t, "world", gjson.GetBytes(upstream.lastBody, "input.1").String())
|
||||||
|
require.Equal(t, "float", gjson.GetBytes(upstream.lastBody, "encoding_format").String())
|
||||||
|
require.Equal(t, int64(256), gjson.GetBytes(upstream.lastBody, "dimensions").Int())
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user