fix(openai): record zero-cost usage for unpriced models
This commit is contained in:
parent
dbc8ae658c
commit
6d69ae87c3
@ -2,8 +2,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -110,6 +110,10 @@ type CostBreakdown struct {
|
|||||||
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
|
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrModelPricingUnavailable indicates that none of the configured pricing
|
||||||
|
// sources can price the requested model.
//
// Producers wrap this sentinel with %w, so callers should detect it with
// errors.Is rather than by matching on the message text.
|
||||||
|
var ErrModelPricingUnavailable = errors.New("pricing not found")
|
||||||
|
|
||||||
// BillingService 计费服务
|
// BillingService 计费服务
|
||||||
type BillingService struct {
|
type BillingService struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@ -355,7 +359,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
|||||||
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
return nil, fmt.Errorf("%w for model: %s", ErrModelPricingUnavailable, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
||||||
@ -452,7 +456,7 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
|
|||||||
|
|
||||||
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
||||||
if pricing == nil {
|
if pricing == nil {
|
||||||
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
|
return nil, fmt.Errorf("no pricing available for model: %s: %w", input.Model, ErrModelPricingUnavailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
||||||
|
|||||||
@ -242,6 +242,57 @@ func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing
|
|||||||
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOpenAIGatewayServiceRecordUsage_MissingPricingRecordsZeroCostUsageLog
// verifies that RecordUsage returns no error for a model the test treats as
// unpriced ("deepseek-v4-flash"), still writes exactly one usage log and one
// billing command, and that every recorded cost is zero so that no balance,
// subscription, quota, or rate-limit charge is applied.
func TestOpenAIGatewayServiceRecordUsage_MissingPricingRecordsZeroCostUsageLog(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_missing_pricing",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 1200,
|
||||||
|
OutputTokens: 300,
|
||||||
|
},
|
||||||
|
Model: "deepseek-v4-flash",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1002, Quota: 100, Group: &Group{RateMultiplier: 1}},
|
||||||
|
User: &User{ID: 2002},
|
||||||
|
Account: &Account{ID: 3002, Type: AccountTypeAPIKey},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
// The call must succeed and reach the billing and usage-log repositories
// exactly once each, without invoking any of the direct charge paths.
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||||
|
|
||||||
|
// The usage log keeps the request identity and token counts, carries zero
// cost, and is tagged with the token billing mode.
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "resp_missing_pricing", usageRepo.lastLog.RequestID)
|
||||||
|
require.Equal(t, "deepseek-v4-flash", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, "deepseek-v4-flash", usageRepo.lastLog.RequestedModel)
|
||||||
|
require.Equal(t, 1200, usageRepo.lastLog.InputTokens)
|
||||||
|
require.Equal(t, 300, usageRepo.lastLog.OutputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, string(BillingModeToken), *usageRepo.lastLog.BillingMode)
|
||||||
|
|
||||||
|
// Every cost component of the billing command must be zero.
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.BalanceCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
|
||||||
|
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||||
groupID := int64(11)
|
groupID := int64(11)
|
||||||
groupRate := 1.4
|
groupRate := 1.4
|
||||||
@ -1157,7 +1208,7 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpr
|
|||||||
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_UnpricedTokenModelFallsBackToZeroCostUsageLog(t *testing.T) {
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
@ -1175,9 +1226,14 @@ func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePrice
|
|||||||
Account: &Account{ID: 30},
|
Account: &Account{ID: 30},
|
||||||
})
|
})
|
||||||
|
|
||||||
require.Error(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
require.Equal(t, 0, usageRepo.calls)
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "not-priceable-alias", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, 20, usageRepo.lastLog.InputTokens)
|
||||||
|
require.Equal(t, 10, usageRepo.lastLog.OutputTokens)
|
||||||
|
require.Zero(t, usageRepo.lastLog.TotalCost)
|
||||||
|
require.Zero(t, usageRepo.lastLog.ActualCost)
|
||||||
require.Equal(t, 0, userRepo.deductCalls)
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
require.Equal(t, 0, subRepo.incrementCalls)
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5273,7 +5273,19 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
}
|
}
|
||||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
|
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if !isUsagePricingUnavailableError(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "service.openai_gateway"),
|
||||||
|
zap.Strings("billing_models", billingModels),
|
||||||
|
zap.String("requested_model", input.OriginalModel),
|
||||||
|
zap.String("mapped_model", input.ChannelMappedModel),
|
||||||
|
zap.String("upstream_model", result.UpstreamModel),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
).Warn("openai_usage.pricing_missing_record_zero_cost", zap.Error(err))
|
||||||
|
cost = &CostBreakdown{BillingMode: string(BillingModeToken)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine billing type
|
// Determine billing type
|
||||||
@ -5439,6 +5451,17 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
|||||||
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
|
return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isUsagePricingUnavailableError reports whether err means that the model
// could not be priced. It returns true when err wraps
// ErrModelPricingUnavailable, or when its message contains "no pricing
// available" or "pricing not found" (case-insensitive). A nil err yields false.
//
// NOTE(review): the substring fallback also matches any unrelated error whose
// text happens to contain one of those phrases. If every pricing source wraps
// ErrModelPricingUnavailable, the errors.Is check alone should suffice —
// confirm before removing the fallback.
func isUsagePricingUnavailableError(err error) bool {
|
||||||
|
// Guard first: calling err.Error() below on a nil error would panic.
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrModelPricingUnavailable) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Fallback for errors that carry the message text but do not wrap the sentinel.
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "no pricing available") || strings.Contains(msg, "pricing not found")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
|
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
apiKey *APIKey,
|
apiKey *APIKey,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user