sub2api/backend/internal/service/openai_embeddings.go
2026-05-28 19:39:52 +08:00

241 lines
7.5 KiB
Go

package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardEmbeddings relays a non-streaming embeddings request to the
// OpenAI-compatible upstream configured on account and writes the upstream
// reply to c.
//
// Flow:
//   - The "model" field of body is required. When it is missing or blank, a
//     400 invalid_request_error is written to the client and an error is
//     returned.
//   - The billing model is resolveOpenAIForwardModel(account, originalModel,
//     defaultMappedModel). The upstream model is
//     normalizeOpenAIModelForUpstream(account, billingModel). body is
//     rewritten only when the upstream model differs from the one the client
//     sent.
//   - The request is POSTed to buildOpenAIEmbeddingsURL(<validated base_url>).
//     base_url defaults to https://api.openai.com. The account's API key is
//     sent as a Bearer token.
//
// Error handling:
//   - A missing API key, an invalid base_url, or a request-construction
//     failure returns an error WITHOUT writing a client response.
//     NOTE(review): presumably the caller writes the client error in these
//     cases — confirm.
//   - A transport failure is recorded as an ops "request_error" event. A 502
//     upstream_error is written and an error is returned.
//   - An upstream status >= 400 that shouldFailoverOpenAIUpstreamResponse
//     accepts is recorded as an ops "failover" event and passed to
//     handleOpenAIAccountUpstreamError. It is returned as
//     *UpstreamFailoverError and nothing is written to the client
//     (presumably so the caller can retry elsewhere, or on the same account
//     when RetryableOnSameAccount is set).
//   - Any other upstream status >= 400 is relayed to the client (filtered
//     headers, status, and at most the first 2 MiB of body) and an error is
//     returned.
//
// On success the upstream body is relayed to the client. The returned
// OpenAIForwardResult carries the request id, the token usage parsed from the
// body, the original/billing/upstream model names, and the elapsed time. The
// result is nil whenever the error is non-nil.
func (s *OpenAIGatewayService) ForwardEmbeddings(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	originalModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
	if originalModel == "" {
		writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return nil, fmt.Errorf("missing model in request")
	}
	billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
	upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
	// Only copy/rewrite the payload when the model actually changes.
	upstreamBody := body
	if upstreamModel != originalModel {
		upstreamBody = ReplaceModelInBody(body, upstreamModel)
	}
	logger.L().Debug("openai embeddings: forwarding",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("billing_model", billingModel),
		zap.String("upstream_model", upstreamModel),
	)
	apiKey := account.GetOpenAIApiKey()
	if apiKey == "" {
		return nil, fmt.Errorf("account %d missing api_key", account.ID)
	}
	baseURL := account.GetOpenAIBaseURL()
	if baseURL == "" {
		baseURL = "https://api.openai.com"
	}
	validatedURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		return nil, fmt.Errorf("invalid base_url: %w", err)
	}
	targetURL := buildOpenAIEmbeddingsURL(validatedURL)
	// The upstream request is bound to a context from detachUpstreamContext
	// rather than ctx directly.
	// NOTE(review): releaseUpstreamCtx runs immediately after the request is
	// built, before s.httpUpstream.Do. That is only correct if releasing does
	// not cancel upstreamCtx — confirm against detachUpstreamContext.
	upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
	upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI))
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	upstreamReq.Header.Set("Accept", "application/json")
	// Forward only client headers whose lower-cased name is in
	// openaiCCRawAllowedHeaders. These are appended with Add after the Set
	// calls above, so if the allow-list contains any of those names the
	// upstream would see duplicate values.
	for key, values := range c.Request.Header {
		lowerKey := strings.ToLower(key)
		if openaiCCRawAllowedHeaders[lowerKey] {
			for _, v := range values {
				upstreamReq.Header.Add(key, v)
			}
		}
	}
	// An account-level user agent replaces any forwarded User-Agent values.
	if customUA := account.GetOpenAIUserAgent(); customUA != "" {
		upstreamReq.Header.Set("user-agent", customUA)
	}
	proxyURL := ""
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: no upstream status, so 0 is recorded.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	// resp.Body is read at defer time, so after the swap below this closes the
	// in-memory replacement (a no-op) rather than the original body twice.
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		// Best-effort: buffer at most 2 MiB (2<<20 bytes) of the error body and
		// ignore read errors. The real body is then closed and replaced with
		// an in-memory reader over the buffered bytes.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			// Include the raw upstream body in the ops event only when enabled
			// by config, truncated to LogUpstreamErrorBodyMaxBytes (2048 when
			// unset or non-positive).
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
			// Nothing is written to the client on this path; the caller
			// receives the status and buffered body instead.
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode),
			}
		}
		// Non-failover upstream error: pass it through to the client as-is.
		writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
		return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode)
	}
	respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if err != nil {
		// The too-large case is not answered here. Presumably
		// ReadUpstreamResponseBody already responded via openAITooLargeError —
		// confirm.
		if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
			writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
		}
		return nil, fmt.Errorf("read upstream body: %w", err)
	}
	writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
	return &OpenAIForwardResult{
		// Request id: first non-empty of x-request-id and request-id.
		RequestID:     firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")),
		Usage:         extractOpenAIEmbeddingsUsage(respBody),
		Model:         originalModel,
		BillingModel:  billingModel,
		UpstreamModel: upstreamModel,
		Stream:        false,
		Duration:      time.Since(startTime),
	}, nil
}
// writeOpenAIEmbeddingsUpstreamResponse relays an already-buffered upstream
// response to the client. It copies the upstream headers that pass filter,
// then the upstream status code, then body.
//
// It does nothing when c or resp is nil, or when the client response has
// already been written. Content-Type is always set explicitly. It mirrors the
// upstream value, or falls back to "application/json" when the upstream sent
// none.
func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) {
	switch {
	case c == nil, resp == nil:
		return
	case c.Writer.Written():
		return
	}

	w := c.Writer
	dst := w.Header()
	if resp.Header != nil {
		responseheaders.WriteFilteredHeaders(dst, resp.Header, filter)
	}

	// Header.Get on a nil http.Header returns "", so no nil check is needed here.
	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	dst.Set("Content-Type", contentType)

	w.WriteHeader(resp.StatusCode)
	_, _ = w.Write(body)
}
// writeOpenAIEmbeddingsError writes an OpenAI-style JSON error envelope of
// the form {"error": {"type": errType, "message": message}} with the given
// HTTP status code.
func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(statusCode, gin.H{"error": detail})
}
// extractOpenAIEmbeddingsUsage reads the top-level "usage" object of an
// upstream embeddings response into an OpenAIUsage. If "usage" is absent or
// is not a JSON object, the zero value is returned.
//
// Each counter accepts several field spellings and takes the first strictly
// positive one, in the order listed:
//   - input: prompt_tokens, input_tokens, total_tokens
//   - output: completion_tokens, output_tokens
//   - cache read: prompt_tokens_details.cached_tokens,
//     input_tokens_details.cached_tokens, cache_read_tokens,
//     cache_read_input_tokens
//   - cache creation: cache_creation_tokens, cache_creation_input_tokens,
//     input_tokens_details.cache_creation_tokens
func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage {
	u := gjson.GetBytes(body, "usage")
	if !(u.Exists() && u.IsObject()) {
		return OpenAIUsage{}
	}

	// pick resolves each path inside the usage object and returns the first
	// positive integer among them (0 if none).
	pick := func(paths ...string) int {
		results := make([]gjson.Result, 0, len(paths))
		for _, p := range paths {
			results = append(results, u.Get(p))
		}
		return firstPositiveGJSONInt(results...)
	}

	return OpenAIUsage{
		InputTokens: pick(
			"prompt_tokens",
			"input_tokens",
			"total_tokens",
		),
		OutputTokens: pick(
			"completion_tokens",
			"output_tokens",
		),
		CacheReadInputTokens: pick(
			"prompt_tokens_details.cached_tokens",
			"input_tokens_details.cached_tokens",
			"cache_read_tokens",
			"cache_read_input_tokens",
		),
		CacheCreationInputTokens: pick(
			"cache_creation_tokens",
			"cache_creation_input_tokens",
			"input_tokens_details.cache_creation_tokens",
		),
	}
}
// firstPositiveGJSONInt scans values in argument order. It returns the
// integer value of the first one that exists and is strictly positive.
// Missing results and zero or negative values are skipped. It returns 0 when
// nothing qualifies, including when called with no arguments.
func firstPositiveGJSONInt(values ...gjson.Result) int {
	for i := range values {
		if v := values[i]; v.Exists() {
			if n := int(v.Int()); n > 0 {
				return n
			}
		}
	}
	return 0
}
// buildOpenAIEmbeddingsURL returns the embeddings endpoint for the given
// upstream base URL. Joining the base and the path is delegated to
// buildOpenAIEndpointURL.
func buildOpenAIEmbeddingsURL(base string) string {
	const embeddingsPath = "/v1/embeddings"
	return buildOpenAIEndpointURL(base, embeddingsPath)
}