From ccace69d4e074929d2470b9fa995a9e9639b9b06 Mon Sep 17 00:00:00 2001 From: Wey Gu Date: Thu, 28 May 2026 19:39:52 +0800 Subject: [PATCH] Add OpenAI embeddings gateway --- backend/internal/handler/endpoint.go | 5 +- backend/internal/handler/endpoint_test.go | 2 + backend/internal/handler/openai_embeddings.go | 253 ++++++++++++++++++ backend/internal/server/routes/gateway.go | 26 ++ backend/internal/service/openai_embeddings.go | 240 +++++++++++++++++ .../service/openai_embeddings_test.go | 106 ++++++++ 6 files changed, 631 insertions(+), 1 deletion(-) create mode 100644 backend/internal/handler/openai_embeddings.go create mode 100644 backend/internal/service/openai_embeddings.go create mode 100644 backend/internal/service/openai_embeddings_test.go diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go index db29618a..0d6f4b3c 100644 --- a/backend/internal/handler/endpoint.go +++ b/backend/internal/handler/endpoint.go @@ -17,6 +17,7 @@ import ( const ( EndpointMessages = "/v1/messages" EndpointChatCompletions = "/v1/chat/completions" + EndpointEmbeddings = "/v1/embeddings" EndpointResponses = "/v1/responses" EndpointImagesGenerations = "/v1/images/generations" EndpointImagesEdits = "/v1/images/edits" @@ -42,6 +43,8 @@ const ( func NormalizeInboundEndpoint(path string) string { path = strings.TrimSpace(path) switch { + case strings.Contains(path, EndpointEmbeddings): + return EndpointEmbeddings case strings.Contains(path, EndpointChatCompletions): return EndpointChatCompletions case strings.Contains(path, EndpointMessages): @@ -75,7 +78,7 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { switch platform { case service.PlatformOpenAI: - if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits { + if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits { return inbound } // OpenAI forwards everything to the Responses API. diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go index 369c5fa7..42b6d6e7 100644 --- a/backend/internal/handler/endpoint_test.go +++ b/backend/internal/handler/endpoint_test.go @@ -24,6 +24,7 @@ func TestNormalizeInboundEndpoint(t *testing.T) { // Direct canonical paths. {"/v1/messages", EndpointMessages}, {"/v1/chat/completions", EndpointChatCompletions}, + {"/v1/embeddings", EndpointEmbeddings}, {"/v1/responses", EndpointResponses}, {"/v1/images/generations", EndpointImagesGenerations}, {"/v1/images/edits", EndpointImagesEdits}, @@ -77,6 +78,7 @@ func TestDeriveUpstreamEndpoint(t *testing.T) { {"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"}, {"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses}, {"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses}, + {"openai embeddings", EndpointEmbeddings, "/v1/embeddings", service.PlatformOpenAI, EndpointEmbeddings}, {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations}, {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits}, diff --git a/backend/internal/handler/openai_embeddings.go b/backend/internal/handler/openai_embeddings.go new file mode 100644 index 00000000..bbb67044 --- /dev/null +++ b/backend/internal/handler/openai_embeddings.go @@ -0,0 +1,253 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "strconv" + "strings" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// Embeddings handles the OpenAI-compatible Embeddings API. +// POST /v1/embeddings +func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { + streamStarted := false + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.embeddings", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqLog = reqLog.With(zap.String("model", reqModel)) + setOpsRequestContext(c, reqModel, false) + setOpsEndpointContext(c, "", int16(service.RequestTypeSync)) + + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { + reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err)) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + h.errorResponse(c, status, code, message) + return + } + + failedAccountIDs := make(map[int64]struct{}) + var lastFailoverErr *service.UpstreamFailoverError + switchCount := 0 + maxAccountSwitches := h.maxAccountSwitches + if maxAccountSwitches <= 0 { + maxAccountSwitches = 3 + } + routingStart := time.Now() + + for { + selection, _, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + "", + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportHTTPSSE, + false, + ) + if err != nil { + reqLog.Warn("openai_embeddings.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") + return + } + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, false) + } else { + h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed") + } + return + } + if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + account := selection.Account + if account.Type != service.AccountTypeAPIKey { + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + failedAccountIDs[account.ID] = struct{}{} + continue + } + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog) + if !accountAcquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + writerSizeBeforeForward := c.Writer.Size() + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "") + }() + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, true) + return + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, false) + return + } + switchCount++ + reqLog.Warn("openai_embeddings.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + if c.Writer.Size() == writerSizeBeforeForward { + h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + } + reqLog.Warn("openai_embeddings.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.embeddings"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_embeddings.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_embeddings.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index efc0687f..b039a6ec 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -89,6 +89,19 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + gateway.POST("/embeddings", func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Embeddings API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Embeddings(c) + }) gateway.POST("/images/generations", func(c *gin.Context) { if getGroupPlatform(c) != service.PlatformOpenAI { service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) @@ -158,6 +171,19 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + r.POST("/embeddings", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Embeddings API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Embeddings(c) + }) r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { if getGroupPlatform(c) != service.PlatformOpenAI { service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) diff --git a/backend/internal/service/openai_embeddings.go b/backend/internal/service/openai_embeddings.go new file mode 100644 index 00000000..359df3bb --- /dev/null +++ b/backend/internal/service/openai_embeddings.go @@ -0,0 +1,240 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +func (s *OpenAIGatewayService) ForwardEmbeddings( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + originalModel := strings.TrimSpace(gjson.GetBytes(body, "model").String()) + if originalModel == "" { + writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return nil, fmt.Errorf("missing model in request") + } + + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + upstreamBody := body + if upstreamModel != originalModel { + upstreamBody = ReplaceModelInBody(body, upstreamModel) + } + + logger.L().Debug("openai embeddings: forwarding", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), + ) + + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + targetURL := buildOpenAIEmbeddingsURL(validatedURL) + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI)) + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + upstreamReq.Header.Set("Accept", "application/json") + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiCCRawAllowedHeaders[lowerKey] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + if customUA := account.GetOpenAIUserAgent(); customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), + } + } + writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter) + return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode) + } + + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter) + + return &OpenAIForwardResult{ + RequestID: firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")), + Usage: extractOpenAIEmbeddingsUsage(respBody), + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) { + if c == nil || resp == nil { + return + } + if c.Writer.Written() { + return + } + if resp.Header != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, filter) + } + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Writer.Header().Set("Content-Type", ct) + } else { + c.Writer.Header().Set("Content-Type", "application/json") + } + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(body) +} + +func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage { + usage := gjson.GetBytes(body, "usage") + if !usage.Exists() || !usage.IsObject() { + return OpenAIUsage{} + } + inputTokens := firstPositiveGJSONInt( + usage.Get("prompt_tokens"), + usage.Get("input_tokens"), + usage.Get("total_tokens"), + ) + outputTokens := firstPositiveGJSONInt( + usage.Get("completion_tokens"), + usage.Get("output_tokens"), + ) + cacheReadTokens := firstPositiveGJSONInt( + usage.Get("prompt_tokens_details.cached_tokens"), + usage.Get("input_tokens_details.cached_tokens"), + usage.Get("cache_read_tokens"), + usage.Get("cache_read_input_tokens"), + ) + cacheCreationTokens := firstPositiveGJSONInt( + usage.Get("cache_creation_tokens"), + usage.Get("cache_creation_input_tokens"), + usage.Get("input_tokens_details.cache_creation_tokens"), + ) + return OpenAIUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadInputTokens: cacheReadTokens, + CacheCreationInputTokens: cacheCreationTokens, + } +} + +func firstPositiveGJSONInt(values ...gjson.Result) int { + for _, value := range values { + if !value.Exists() { + continue + } + n := int(value.Int()) + if n > 0 { + return n + } + } + return 0 +} + +func buildOpenAIEmbeddingsURL(base string) string { + return buildOpenAIEndpointURL(base, "/v1/embeddings") +} diff --git a/backend/internal/service/openai_embeddings_test.go b/backend/internal/service/openai_embeddings_test.go new file mode 100644 index 00000000..c7e89d64 --- /dev/null +++ b/backend/internal/service/openai_embeddings_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildOpenAIEmbeddingsURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + want string + }{ + {"bare domain", "https://api.openai.com", "https://api.openai.com/v1/embeddings"}, + {"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/embeddings"}, + {"already embeddings", "https://api.openai.com/v1/embeddings", "https://api.openai.com/v1/embeddings"}, + {"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/embeddings"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, buildOpenAIEmbeddingsURL(tt.base)) + }) + } +} + +func TestForwardEmbeddings_APIKeyPassthroughRecordsUsageAndBatchInput(t *testing.T) { + gin.SetMode(gin.TestMode) + + reqBody := []byte(`{ + "model":"nowledge-embedding", + "input":["hello","world"], + "encoding_format":"float", + "dimensions":256 + }`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/embeddings", bytes.NewReader(reqBody)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"emb-rid"}, + }, + Body: io.NopCloser(strings.NewReader(`{ + "object":"list", + "data":[ + {"object":"embedding","index":0,"embedding":[0.1,0.2]}, + {"object":"embedding","index":1,"embedding":[0.3,0.4]} + ], + "model":"jina-embeddings-v5-text-small", + "usage":{"prompt_tokens":13,"total_tokens":13} + }`)), + }} + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + } + account := &Account{ + ID: 42, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.jina.ai", + "model_mapping": map[string]any{ + "nowledge-embedding": "jina-embeddings-v5-text-small", + }, + }, + } + + result, err := svc.ForwardEmbeddings(context.Background(), c, account, reqBody, "") + + require.NoError(t, err) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, result) + require.Equal(t, "emb-rid", result.RequestID) + require.Equal(t, "nowledge-embedding", result.Model) + require.Equal(t, "jina-embeddings-v5-text-small", result.BillingModel) + require.Equal(t, "jina-embeddings-v5-text-small", result.UpstreamModel) + require.Equal(t, 13, result.Usage.InputTokens) + require.Equal(t, 0, result.Usage.OutputTokens) + require.Equal(t, "https://api.jina.ai/v1/embeddings", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "jina-embeddings-v5-text-small", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, int64(2), gjson.GetBytes(upstream.lastBody, "input.#").Int()) + require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "input.0").String()) + require.Equal(t, "world", gjson.GetBytes(upstream.lastBody, "input.1").String()) + require.Equal(t, "float", gjson.GetBytes(upstream.lastBody, "encoding_format").String()) + require.Equal(t, int64(256), gjson.GetBytes(upstream.lastBody, "dimensions").Int()) +}