241 lines
7.5 KiB
Go
241 lines
7.5 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
func (s *OpenAIGatewayService) ForwardEmbeddings(
|
|
ctx context.Context,
|
|
c *gin.Context,
|
|
account *Account,
|
|
body []byte,
|
|
defaultMappedModel string,
|
|
) (*OpenAIForwardResult, error) {
|
|
startTime := time.Now()
|
|
|
|
originalModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
|
if originalModel == "" {
|
|
writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
|
return nil, fmt.Errorf("missing model in request")
|
|
}
|
|
|
|
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
|
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
|
upstreamBody := body
|
|
if upstreamModel != originalModel {
|
|
upstreamBody = ReplaceModelInBody(body, upstreamModel)
|
|
}
|
|
|
|
logger.L().Debug("openai embeddings: forwarding",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.String("original_model", originalModel),
|
|
zap.String("billing_model", billingModel),
|
|
zap.String("upstream_model", upstreamModel),
|
|
)
|
|
|
|
apiKey := account.GetOpenAIApiKey()
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
|
}
|
|
baseURL := account.GetOpenAIBaseURL()
|
|
if baseURL == "" {
|
|
baseURL = "https://api.openai.com"
|
|
}
|
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid base_url: %w", err)
|
|
}
|
|
targetURL := buildOpenAIEmbeddingsURL(validatedURL)
|
|
|
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
|
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
|
releaseUpstreamCtx()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
|
}
|
|
upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI))
|
|
upstreamReq.Header.Set("Content-Type", "application/json")
|
|
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
|
upstreamReq.Header.Set("Accept", "application/json")
|
|
for key, values := range c.Request.Header {
|
|
lowerKey := strings.ToLower(key)
|
|
if openaiCCRawAllowedHeaders[lowerKey] {
|
|
for _, v := range values {
|
|
upstreamReq.Header.Add(key, v)
|
|
}
|
|
}
|
|
}
|
|
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
|
|
upstreamReq.Header.Set("user-agent", customUA)
|
|
}
|
|
|
|
proxyURL := ""
|
|
if account.Proxy != nil {
|
|
proxyURL = account.Proxy.URL()
|
|
}
|
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
|
if err != nil {
|
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
|
setOpsUpstreamError(c, 0, safeErr, "")
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
Platform: account.Platform,
|
|
AccountID: account.ID,
|
|
AccountName: account.Name,
|
|
UpstreamStatusCode: 0,
|
|
Kind: "request_error",
|
|
Message: safeErr,
|
|
})
|
|
writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode >= 400 {
|
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
_ = resp.Body.Close()
|
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|
|
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
|
upstreamDetail := ""
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
if maxBytes <= 0 {
|
|
maxBytes = 2048
|
|
}
|
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
|
}
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
Platform: account.Platform,
|
|
AccountID: account.ID,
|
|
AccountName: account.Name,
|
|
UpstreamStatusCode: resp.StatusCode,
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|
Kind: "failover",
|
|
Message: upstreamMsg,
|
|
Detail: upstreamDetail,
|
|
})
|
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
|
|
return nil, &UpstreamFailoverError{
|
|
StatusCode: resp.StatusCode,
|
|
ResponseBody: respBody,
|
|
RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode),
|
|
}
|
|
}
|
|
writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
|
|
return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
|
if err != nil {
|
|
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
|
}
|
|
return nil, fmt.Errorf("read upstream body: %w", err)
|
|
}
|
|
|
|
writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter)
|
|
|
|
return &OpenAIForwardResult{
|
|
RequestID: firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")),
|
|
Usage: extractOpenAIEmbeddingsUsage(respBody),
|
|
Model: originalModel,
|
|
BillingModel: billingModel,
|
|
UpstreamModel: upstreamModel,
|
|
Stream: false,
|
|
Duration: time.Since(startTime),
|
|
}, nil
|
|
}
|
|
|
|
func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) {
|
|
if c == nil || resp == nil {
|
|
return
|
|
}
|
|
if c.Writer.Written() {
|
|
return
|
|
}
|
|
if resp.Header != nil {
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, filter)
|
|
}
|
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
|
c.Writer.Header().Set("Content-Type", ct)
|
|
} else {
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
}
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, _ = c.Writer.Write(body)
|
|
}
|
|
|
|
func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) {
|
|
c.JSON(statusCode, gin.H{
|
|
"error": gin.H{
|
|
"type": errType,
|
|
"message": message,
|
|
},
|
|
})
|
|
}
|
|
|
|
func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage {
|
|
usage := gjson.GetBytes(body, "usage")
|
|
if !usage.Exists() || !usage.IsObject() {
|
|
return OpenAIUsage{}
|
|
}
|
|
inputTokens := firstPositiveGJSONInt(
|
|
usage.Get("prompt_tokens"),
|
|
usage.Get("input_tokens"),
|
|
usage.Get("total_tokens"),
|
|
)
|
|
outputTokens := firstPositiveGJSONInt(
|
|
usage.Get("completion_tokens"),
|
|
usage.Get("output_tokens"),
|
|
)
|
|
cacheReadTokens := firstPositiveGJSONInt(
|
|
usage.Get("prompt_tokens_details.cached_tokens"),
|
|
usage.Get("input_tokens_details.cached_tokens"),
|
|
usage.Get("cache_read_tokens"),
|
|
usage.Get("cache_read_input_tokens"),
|
|
)
|
|
cacheCreationTokens := firstPositiveGJSONInt(
|
|
usage.Get("cache_creation_tokens"),
|
|
usage.Get("cache_creation_input_tokens"),
|
|
usage.Get("input_tokens_details.cache_creation_tokens"),
|
|
)
|
|
return OpenAIUsage{
|
|
InputTokens: inputTokens,
|
|
OutputTokens: outputTokens,
|
|
CacheReadInputTokens: cacheReadTokens,
|
|
CacheCreationInputTokens: cacheCreationTokens,
|
|
}
|
|
}
|
|
|
|
func firstPositiveGJSONInt(values ...gjson.Result) int {
|
|
for _, value := range values {
|
|
if !value.Exists() {
|
|
continue
|
|
}
|
|
n := int(value.Int())
|
|
if n > 0 {
|
|
return n
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func buildOpenAIEmbeddingsURL(base string) string {
|
|
return buildOpenAIEndpointURL(base, "/v1/embeddings")
|
|
}
|