package service

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"go.uber.org/zap"
)

// ForwardEmbeddings relays an embeddings request to the OpenAI-compatible
// upstream configured on account and copies the upstream reply back to the
// client through c.
//
// The "model" field of body is mandatory. It is resolved to a billing model
// and then to an upstream model; when the upstream model differs from the
// requested one, the model in the forwarded body is replaced.
//
// Outcomes, as implemented below:
//   - missing model: a 400 JSON error is written and an error is returned;
//   - missing API key, invalid base URL, or request construction failure: an
//     error is returned and nothing is written to the client here;
//   - transport failure: the failure is recorded via the ops helpers, a 502
//     JSON error is written and an error is returned;
//   - upstream status >= 400: either *UpstreamFailoverError is returned
//     (nothing written to the client) or the upstream reply is relayed and a
//     plain error is returned;
//   - success: the upstream reply is relayed and an *OpenAIForwardResult
//     carrying request ID, usage, model names and elapsed time is returned.
func (s *OpenAIGatewayService) ForwardEmbeddings(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	begin := time.Now()

	// --- model resolution ---------------------------------------------------
	requestedModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
	if requestedModel == "" {
		writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return nil, fmt.Errorf("missing model in request")
	}

	billingModel := resolveOpenAIForwardModel(account, requestedModel, defaultMappedModel)
	upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)

	// Only rewrite the payload when the upstream expects a different name.
	payload := body
	if upstreamModel != requestedModel {
		payload = ReplaceModelInBody(body, upstreamModel)
	}

	logger.L().Debug("openai embeddings: forwarding",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", requestedModel),
		zap.String("billing_model", billingModel),
		zap.String("upstream_model", upstreamModel),
	)

	// --- credentials and endpoint -------------------------------------------
	apiKey := account.GetOpenAIApiKey()
	if apiKey == "" {
		return nil, fmt.Errorf("account %d missing api_key", account.ID)
	}

	rawBase := account.GetOpenAIBaseURL()
	if rawBase == "" {
		rawBase = "https://api.openai.com"
	}
	checkedBase, err := s.validateUpstreamBaseURL(rawBase)
	if err != nil {
		return nil, fmt.Errorf("invalid base_url: %w", err)
	}
	endpoint := buildOpenAIEmbeddingsURL(checkedBase)

	// --- upstream request ---------------------------------------------------
	// NOTE(review): release is invoked as soon as the request is built, before
	// it is sent; this presumes the release func does not cancel the returned
	// context — confirm against detachUpstreamContext.
	detachedCtx, release := detachUpstreamContext(ctx)
	req, err := http.NewRequestWithContext(detachedCtx, http.MethodPost, endpoint, bytes.NewReader(payload))
	release()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI))

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Accept", "application/json")

	// Copy through only the client headers present in the allow-list; the
	// lookup key is the lower-cased header name.
	for name, vals := range c.Request.Header {
		if !openaiCCRawAllowedHeaders[strings.ToLower(name)] {
			continue
		}
		for _, v := range vals {
			req.Header.Add(name, v)
		}
	}

	// An account-level user agent overrides whatever was set so far.
	if ua := account.GetOpenAIUserAgent(); ua != "" {
		req.Header.Set("user-agent", ua)
	}

	var proxyURL string
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	// The closure reads resp.Body when it runs, so it closes whichever body is
	// installed at return time (the error path below swaps it).
	defer func() { _ = resp.Body.Close() }()

	// --- upstream error status ----------------------------------------------
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body; a read error is ignored and
		// whatever was read is used.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(errBody))

		upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(errBody)))

		if !s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, errBody) {
			// Not a failover case: hand the upstream reply to the client.
			writeOpenAIEmbeddingsUpstreamResponse(c, resp, errBody, s.responseHeaderFilter)
			return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode)
		}

		// Failover case: record the event and let the caller decide; nothing
		// is written to the client on this path.
		detail := ""
		if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
			limit := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
			if limit <= 0 {
				limit = 2048
			}
			detail = truncateString(string(errBody), limit)
		}
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: resp.StatusCode,
			UpstreamRequestID:  resp.Header.Get("x-request-id"),
			Kind:               "failover",
			Message:            upstreamMsg,
			Detail:             detail,
		})
		s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, errBody, upstreamModel)
		return nil, &UpstreamFailoverError{
			StatusCode:             resp.StatusCode,
			ResponseBody:           errBody,
			RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode),
		}
	}

	// --- success ------------------------------------------------------------
	okBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if err != nil {
		// NOTE(review): no 502 is written here for the too-large case,
		// presumably because ReadUpstreamResponseBody already responded via
		// the callback it was given — confirm.
		if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
			writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
		}
		return nil, fmt.Errorf("read upstream body: %w", err)
	}

	writeOpenAIEmbeddingsUpstreamResponse(c, resp, okBody, s.responseHeaderFilter)

	return &OpenAIForwardResult{
		RequestID:     firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")),
		Usage:         extractOpenAIEmbeddingsUsage(okBody),
		Model:         requestedModel,
		BillingModel:  billingModel,
		UpstreamModel: upstreamModel,
		Stream:        false,
		Duration:      time.Since(begin),
	}, nil
}

// writeOpenAIEmbeddingsUpstreamResponse relays an upstream reply to the client:
// filtered upstream headers, the upstream Content-Type (or "application/json"
// when the upstream sent none), the upstream status code, and body verbatim.
// It does nothing when c or resp is nil, or when a response was already
// written on c.
func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) {
	if c == nil || resp == nil || c.Writer.Written() {
		return
	}

	out := c.Writer.Header()
	if resp.Header != nil {
		responseheaders.WriteFilteredHeaders(out, resp.Header, filter)
	}

	// Content-Type is always set explicitly, after the filtered copy.
	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	out.Set("Content-Type", contentType)

	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(body)
}

// writeOpenAIEmbeddingsError responds with statusCode and a JSON document of
// the form {"error": {"type": errType, "message": message}}.
func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) {
	payload := gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	}
	c.JSON(statusCode, payload)
}

// extractOpenAIEmbeddingsUsage reads token counts from the "usage" object of
// an upstream response body. For each counter the first candidate key holding
// a positive integer wins; a missing or non-object "usage" yields the zero
// OpenAIUsage.
func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage {
	usage := gjson.GetBytes(body, "usage")
	if !(usage.Exists() && usage.IsObject()) {
		return OpenAIUsage{}
	}

	return OpenAIUsage{
		// total_tokens is the last-resort source for the input count.
		InputTokens: firstPositiveGJSONInt(
			usage.Get("prompt_tokens"),
			usage.Get("input_tokens"),
			usage.Get("total_tokens"),
		),
		OutputTokens: firstPositiveGJSONInt(
			usage.Get("completion_tokens"),
			usage.Get("output_tokens"),
		),
		CacheReadInputTokens: firstPositiveGJSONInt(
			usage.Get("prompt_tokens_details.cached_tokens"),
			usage.Get("input_tokens_details.cached_tokens"),
			usage.Get("cache_read_tokens"),
			usage.Get("cache_read_input_tokens"),
		),
		CacheCreationInputTokens: firstPositiveGJSONInt(
			usage.Get("cache_creation_tokens"),
			usage.Get("cache_creation_input_tokens"),
			usage.Get("input_tokens_details.cache_creation_tokens"),
		),
	}
}

// firstPositiveGJSONInt returns the first value, in argument order, that
// exists and converts to an integer greater than zero. It returns 0 when no
// value qualifies.
func firstPositiveGJSONInt(values ...gjson.Result) int {
	for i := range values {
		if candidate := values[i]; candidate.Exists() {
			if n := int(candidate.Int()); n > 0 {
				return n
			}
		}
	}
	return 0
}

// buildOpenAIEmbeddingsURL returns the embeddings endpoint URL for base by
// delegating to buildOpenAIEndpointURL with the "/v1/embeddings" path.
func buildOpenAIEmbeddingsURL(base string) string {
	const embeddingsPath = "/v1/embeddings"
	return buildOpenAIEndpointURL(base, embeddingsPath)
}