sub2api/backend/internal/handler/openai_embeddings.go

248 lines
8.2 KiB
Go

package handler
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Embeddings handles the OpenAI-compatible Embeddings API.
// POST /v1/embeddings
func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.embeddings",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqLog = reqLog.With(zap.String("model", reqModel))
setOpsRequestContext(c, reqModel, false)
setOpsEndpointContext(c, "", int16(service.RequestTypeSync))
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.errorResponse(c, status, code, message)
return
}
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
switchCount := 0
maxAccountSwitches := h.maxAccountSwitches
if maxAccountSwitches <= 0 {
maxAccountSwitches = 3
}
routingStart := time.Now()
for {
selection, _, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
c.Request.Context(),
apiKey.GroupID,
"",
"",
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportHTTPSSE,
service.OpenAIEndpointCapabilityEmbeddings,
false,
)
if err != nil {
reqLog.Warn("openai_embeddings.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, false)
} else {
h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed")
}
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
if !accountAcquired {
return
}
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "")
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, false)
return
}
switchCount++
reqLog.Warn("openai_embeddings.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
if c.Writer.Size() == writerSizeBeforeForward {
h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
}
reqLog.Warn("openai_embeddings.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.embeddings"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_embeddings.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_embeddings.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}