Add OpenAI embeddings gateway

This commit is contained in:
Wey Gu 2026-05-28 19:39:52 +08:00
parent 89d96f4b25
commit ccace69d4e
6 changed files with 631 additions and 1 deletions

View File

@ -17,6 +17,7 @@ import (
const (
EndpointMessages = "/v1/messages"
EndpointChatCompletions = "/v1/chat/completions"
EndpointEmbeddings = "/v1/embeddings"
EndpointResponses = "/v1/responses"
EndpointImagesGenerations = "/v1/images/generations"
EndpointImagesEdits = "/v1/images/edits"
@ -42,6 +43,8 @@ const (
func NormalizeInboundEndpoint(path string) string {
path = strings.TrimSpace(path)
switch {
case strings.Contains(path, EndpointEmbeddings):
return EndpointEmbeddings
case strings.Contains(path, EndpointChatCompletions):
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
@ -75,7 +78,7 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
switch platform {
case service.PlatformOpenAI:
if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
return inbound
}
// OpenAI forwards everything to the Responses API.

View File

@ -24,6 +24,7 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
// Direct canonical paths.
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/embeddings", EndpointEmbeddings},
{"/v1/responses", EndpointResponses},
{"/v1/images/generations", EndpointImagesGenerations},
{"/v1/images/edits", EndpointImagesEdits},
@ -77,6 +78,7 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
{"openai embeddings", EndpointEmbeddings, "/v1/embeddings", service.PlatformOpenAI, EndpointEmbeddings},
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},

View File

@ -0,0 +1,253 @@
package handler
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Embeddings handles the OpenAI-compatible Embeddings API.
// POST /v1/embeddings
func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.embeddings",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqLog = reqLog.With(zap.String("model", reqModel))
setOpsRequestContext(c, reqModel, false)
setOpsEndpointContext(c, "", int16(service.RequestTypeSync))
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.errorResponse(c, status, code, message)
return
}
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
switchCount := 0
maxAccountSwitches := h.maxAccountSwitches
if maxAccountSwitches <= 0 {
maxAccountSwitches = 3
}
routingStart := time.Now()
for {
selection, _, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
"",
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportHTTPSSE,
false,
)
if err != nil {
reqLog.Warn("openai_embeddings.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, false)
} else {
h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed")
}
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
account := selection.Account
if account.Type != service.AccountTypeAPIKey {
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
failedAccountIDs[account.ID] = struct{}{}
continue
}
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
if !accountAcquired {
return
}
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "")
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, false)
return
}
switchCount++
reqLog.Warn("openai_embeddings.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
if c.Writer.Size() == writerSizeBeforeForward {
h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
}
reqLog.Warn("openai_embeddings.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.embeddings"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_embeddings.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_embeddings.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}

View File

@ -89,6 +89,19 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
gateway.POST("/embeddings", func(c *gin.Context) {
if getGroupPlatform(c) != service.PlatformOpenAI {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"type": "not_found_error",
"message": "Embeddings API is not supported for this platform",
},
})
return
}
h.OpenAIGateway.Embeddings(c)
})
gateway.POST("/images/generations", func(c *gin.Context) {
if getGroupPlatform(c) != service.PlatformOpenAI {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
@ -158,6 +171,19 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
r.POST("/embeddings", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) != service.PlatformOpenAI {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)
c.JSON(http.StatusNotFound, gin.H{
"error": gin.H{
"type": "not_found_error",
"message": "Embeddings API is not supported for this platform",
},
})
return
}
h.OpenAIGateway.Embeddings(c)
})
r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) != service.PlatformOpenAI {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate)

View File

@ -0,0 +1,240 @@
package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
func (s *OpenAIGatewayService) ForwardEmbeddings(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
originalModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
if originalModel == "" {
writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return nil, fmt.Errorf("missing model in request")
}
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
upstreamBody := body
if upstreamModel != originalModel {
upstreamBody = ReplaceModelInBody(body, upstreamModel)
}
logger.L().Debug("openai embeddings: forwarding",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
)
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return nil, fmt.Errorf("account %d missing api_key", account.ID)
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
targetURL := buildOpenAIEmbeddingsURL(validatedURL)
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI))
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
upstreamReq.Header.Set("Accept", "application/json")
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode),
}
}
writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode)
}
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
}
return nil, fmt.Errorf("read upstream body: %w", err)
}
writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
return &OpenAIForwardResult{
RequestID: firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")),
Usage: extractOpenAIEmbeddingsUsage(respBody),
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) {
if c == nil || resp == nil {
return
}
if c.Writer.Written() {
return
}
if resp.Header != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, filter)
}
if ct := resp.Header.Get("Content-Type"); ct != "" {
c.Writer.Header().Set("Content-Type", ct)
} else {
c.Writer.Header().Set("Content-Type", "application/json")
}
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(body)
}
func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage {
usage := gjson.GetBytes(body, "usage")
if !usage.Exists() || !usage.IsObject() {
return OpenAIUsage{}
}
inputTokens := firstPositiveGJSONInt(
usage.Get("prompt_tokens"),
usage.Get("input_tokens"),
usage.Get("total_tokens"),
)
outputTokens := firstPositiveGJSONInt(
usage.Get("completion_tokens"),
usage.Get("output_tokens"),
)
cacheReadTokens := firstPositiveGJSONInt(
usage.Get("prompt_tokens_details.cached_tokens"),
usage.Get("input_tokens_details.cached_tokens"),
usage.Get("cache_read_tokens"),
usage.Get("cache_read_input_tokens"),
)
cacheCreationTokens := firstPositiveGJSONInt(
usage.Get("cache_creation_tokens"),
usage.Get("cache_creation_input_tokens"),
usage.Get("input_tokens_details.cache_creation_tokens"),
)
return OpenAIUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cacheReadTokens,
CacheCreationInputTokens: cacheCreationTokens,
}
}
func firstPositiveGJSONInt(values ...gjson.Result) int {
for _, value := range values {
if !value.Exists() {
continue
}
n := int(value.Int())
if n > 0 {
return n
}
}
return 0
}
func buildOpenAIEmbeddingsURL(base string) string {
return buildOpenAIEndpointURL(base, "/v1/embeddings")
}

View File

@ -0,0 +1,106 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildOpenAIEmbeddingsURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/embeddings"},
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/embeddings"},
{"already embeddings", "https://api.openai.com/v1/embeddings", "https://api.openai.com/v1/embeddings"},
{"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/embeddings"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, buildOpenAIEmbeddingsURL(tt.base))
})
}
}
func TestForwardEmbeddings_APIKeyPassthroughRecordsUsageAndBatchInput(t *testing.T) {
gin.SetMode(gin.TestMode)
reqBody := []byte(`{
"model":"nowledge-embedding",
"input":["hello","world"],
"encoding_format":"float",
"dimensions":256
}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/embeddings", bytes.NewReader(reqBody))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"emb-rid"},
},
Body: io.NopCloser(strings.NewReader(`{
"object":"list",
"data":[
{"object":"embedding","index":0,"embedding":[0.1,0.2]},
{"object":"embedding","index":1,"embedding":[0.3,0.4]}
],
"model":"jina-embeddings-v5-text-small",
"usage":{"prompt_tokens":13,"total_tokens":13}
}`)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: upstream,
}
account := &Account{
ID: 42,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://api.jina.ai",
"model_mapping": map[string]any{
"nowledge-embedding": "jina-embeddings-v5-text-small",
},
},
}
result, err := svc.ForwardEmbeddings(context.Background(), c, account, reqBody, "")
require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, result)
require.Equal(t, "emb-rid", result.RequestID)
require.Equal(t, "nowledge-embedding", result.Model)
require.Equal(t, "jina-embeddings-v5-text-small", result.BillingModel)
require.Equal(t, "jina-embeddings-v5-text-small", result.UpstreamModel)
require.Equal(t, 13, result.Usage.InputTokens)
require.Equal(t, 0, result.Usage.OutputTokens)
require.Equal(t, "https://api.jina.ai/v1/embeddings", upstream.lastReq.URL.String())
require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization"))
require.Equal(t, "jina-embeddings-v5-text-small", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, int64(2), gjson.GetBytes(upstream.lastBody, "input.#").Int())
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "input.0").String())
require.Equal(t, "world", gjson.GetBytes(upstream.lastBody, "input.1").String())
require.Equal(t, "float", gjson.GetBytes(upstream.lastBody, "encoding_format").String())
require.Equal(t, int64(256), gjson.GetBytes(upstream.lastBody, "dimensions").Int())
}