sub2api/backend/internal/service/channel_available_test.go
name c26d3ae1b5 feat(channels): 渠道未填价时按 LiteLLM 默认价展示
「可用渠道」展示链路有两个未覆盖场景导致用户看到"未配置定价":

1. admin 在 UI 里建了 ModelPricing 条目但没填任何价格 (常见于
   per_request / image 模式只填了 tier_label 没填单价): 原 fallback
   只检查 Pricing == nil, 这种空条目会跳过 LiteLLM 兜底。
2. LiteLLM 把图片模型标记 mode=image_generation, 但合成器固定按
   token 模式合成, 把 OutputCostPerImage / 图片 token 价丢到错误字段。

改动 (仅 backend/internal/service/channel_available.go):
- 新增 pricingNeedsFallback: 价格字段全空 (含 intervals 全空) 视为
  未配置, 触发 LiteLLM 兜底。
- synthesizePricingFromLiteLLM 加 existing 参数: 优先尊重渠道已选
  BillingMode (per_request / image 也按此模式合成), 没选才看 LiteLLM
  mode, 仍未命中默认 token。
- image / per_request 分支用 OutputCostPerImage 填 PerRequestPrice,
  OutputCostPerImageToken 填 ImageOutputPrice, 让 gpt-image / dall-e
  系列展示出参考价。

仅影响展示链路, 真实计费走 BillingService / ModelPricingResolver
完全不受影响。新增 8 个单元测试覆盖 pricingNeedsFallback 各分支、
合成器三种模式选择、空条目兜底与既有价格保护。
2026-05-15 01:28:13 +08:00

312 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// stubGroupRepoForAvailable is the GroupRepository stub used by the
// ListAvailable tests. Only ListActive carries behavior; every other method
// is irrelevant to these tests and simply returns zero values.
type stubGroupRepoForAvailable struct {
	// activeGroups is the result ListActive hands back when listActiveErr is nil.
	activeGroups []Group
	// listActiveErr, when non-nil, is returned by ListActive (used by the
	// error-propagation tests).
	listActiveErr error
	// listActiveCalls counts ListActive invocations so tests can assert that
	// a short-circuited failure never reaches the group repository.
	listActiveCalls int
}

// ListActive records the call, then returns listActiveErr if it is set and
// activeGroups otherwise.
func (s *stubGroupRepoForAvailable) ListActive(context.Context) ([]Group, error) {
	s.listActiveCalls++
	if err := s.listActiveErr; err != nil {
		return nil, err
	}
	return s.activeGroups, nil
}

// The remaining methods are no-ops that return zero values.

func (*stubGroupRepoForAvailable) Create(context.Context, *Group) error { return nil }

func (*stubGroupRepoForAvailable) GetByID(context.Context, int64) (*Group, error) {
	return nil, nil
}

func (*stubGroupRepoForAvailable) GetByIDLite(context.Context, int64) (*Group, error) {
	return nil, nil
}

func (*stubGroupRepoForAvailable) Update(context.Context, *Group) error { return nil }

func (*stubGroupRepoForAvailable) Delete(context.Context, int64) error { return nil }

func (*stubGroupRepoForAvailable) DeleteCascade(context.Context, int64) ([]int64, error) {
	return nil, nil
}

func (*stubGroupRepoForAvailable) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (*stubGroupRepoForAvailable) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (*stubGroupRepoForAvailable) ListActiveByPlatform(context.Context, string) ([]Group, error) {
	return nil, nil
}

func (*stubGroupRepoForAvailable) ExistsByName(context.Context, string) (bool, error) {
	return false, nil
}

func (*stubGroupRepoForAvailable) GetAccountCount(context.Context, int64) (int64, int64, error) {
	return 0, 0, nil
}

func (*stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
	return 0, nil
}

func (*stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
	return nil, nil
}

func (*stubGroupRepoForAvailable) BindAccountsToGroup(context.Context, int64, []int64) error {
	return nil
}

func (*stubGroupRepoForAvailable) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
	return nil
}
// newAvailableChannelService builds a ChannelService whose channel
// repository's ListAll always succeeds with the given channels. The group
// repository is supplied by the caller; passing an empty stub models "no
// active groups".
func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService {
	listAll := func(context.Context) ([]Channel, error) {
		return channels, nil
	}
	channelRepo := &mockChannelRepository{listAllFn: listAll}
	return NewChannelService(channelRepo, groupRepo, nil, nil)
}
func TestListAvailable_EmptyActiveGroups_NoGroupsAttached(t *testing.T) {
// 活跃分组列表为空时,渠道的 Groups 应为空切片,不报错。
channels := []Channel{{
ID: 1,
Name: "chA",
Status: StatusActive,
GroupIDs: []int64{10, 20},
}}
svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 1)
require.Empty(t, out[0].Groups)
}
func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) {
// 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。
channels := []Channel{{
ID: 1,
Name: "chA",
Status: StatusActive,
GroupIDs: []int64{1, 99},
}}
groupRepo := &stubGroupRepoForAvailable{
activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}},
}
svc := newAvailableChannelService(channels, groupRepo)
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 1)
require.Len(t, out[0].Groups, 1)
require.Equal(t, int64(1), out[0].Groups[0].ID)
}
// TestListAvailable_SortedByName checks the order of the returned channels:
// given "beta", "Alpha", "charlie" the result must be "Alpha", "beta",
// "charlie".
func TestListAvailable_SortedByName(t *testing.T) {
	unsorted := []Channel{
		{ID: 1, Name: "beta"},
		{ID: 2, Name: "Alpha"},
		{ID: 3, Name: "charlie"},
	}
	wantOrder := []string{"Alpha", "beta", "charlie"}

	svc := newAvailableChannelService(unsorted, &stubGroupRepoForAvailable{})
	got, err := svc.ListAvailable(context.Background())

	require.NoError(t, err)
	require.Len(t, got, len(wantOrder))
	for i, name := range wantOrder {
		require.Equal(t, name, got[i].Name)
	}
}
func TestListAvailable_ListAllErrorPropagates(t *testing.T) {
// ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo短路
sentinel := errors.New("list-all-boom")
repo := &mockChannelRepository{
listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel },
}
groupRepo := &stubGroupRepoForAvailable{}
svc := NewChannelService(repo, groupRepo, nil, nil)
out, err := svc.ListAvailable(context.Background())
require.Nil(t, out)
require.ErrorIs(t, err, sentinel)
require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v")
require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive")
}
func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
// groupRepo.ListActive 返回错误时 ListAvailable 应直接返回包装后的错误。
sentinel := errors.New("list-active-boom")
svc := newAvailableChannelService(
[]Channel{{ID: 1, Name: "chA"}},
&stubGroupRepoForAvailable{listActiveErr: sentinel},
)
out, err := svc.ListAvailable(context.Background())
require.Nil(t, out)
require.ErrorIs(t, err, sentinel)
require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v")
}
func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
// 渠道 BillingModelSource 为空时应回填为 BillingModelSourceChannelMapped
// 显式值应原样保留(由 service 层统一处理,避免各 handler 重复默认逻辑)。
channels := []Channel{
{ID: 1, Name: "empty", BillingModelSource: ""},
{ID: 2, Name: "explicit", BillingModelSource: BillingModelSourceUpstream},
}
svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 2)
// 按 Name 查找,避免依赖排序副作用。
byName := make(map[string]string, len(out))
for _, ch := range out {
byName[ch.Name] = ch.BillingModelSource
}
require.Equal(t, BillingModelSourceChannelMapped, byName["empty"])
require.Equal(t, BillingModelSourceUpstream, byName["explicit"])
}
// TestPricingNeedsFallback is a table-driven check of pricingNeedsFallback:
// nil pricing, a pricing with only a billing mode, and a pricing whose
// intervals carry labels but no prices are expected to need fallback; any
// flat or per-interval price set means no fallback.
func TestPricingNeedsFallback(t *testing.T) {
	type testCase struct {
		name    string
		pricing *ChannelModelPricing
		want    bool
	}

	cases := []testCase{
		{name: "nil", pricing: nil, want: true},
		{
			name:    "empty struct",
			pricing: &ChannelModelPricing{BillingMode: BillingModeToken},
			want:    true,
		},
		{
			name: "all-empty intervals",
			pricing: &ChannelModelPricing{
				BillingMode: BillingModeImage,
				Intervals:   []PricingInterval{{TierLabel: "1K"}, {TierLabel: "2K"}},
			},
			want: true,
		},
		{
			name:    "flat input set",
			pricing: &ChannelModelPricing{InputPrice: testPtrFloat64(3e-6)},
			want:    false,
		},
		{
			name:    "flat per_request set",
			pricing: &ChannelModelPricing{PerRequestPrice: testPtrFloat64(0.04)},
			want:    false,
		},
		{
			name: "interval with price",
			pricing: &ChannelModelPricing{
				Intervals: []PricingInterval{{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}},
			},
			want: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := pricingNeedsFallback(tc.pricing)
			require.Equal(t, tc.want, got)
		})
	}
}
// TestSynthesizePricingFromLiteLLM_TokenMode checks that LiteLLM data with
// mode "chat" and no existing channel pricing is synthesized in token billing
// mode, with the input price carried over and a cache-read price present.
func TestSynthesizePricingFromLiteLLM_TokenMode(t *testing.T) {
	source := &LiteLLMModelPricing{
		Mode:                        "chat",
		InputCostPerToken:           3e-6,
		OutputCostPerToken:          1.5e-5,
		CacheCreationInputTokenCost: 3.75e-6,
		CacheReadInputTokenCost:     3e-7,
	}

	pricing := synthesizePricingFromLiteLLM(source, nil)

	require.NotNil(t, pricing)
	require.Equal(t, BillingModeToken, pricing.BillingMode)
	require.NotNil(t, pricing.InputPrice)
	require.InDelta(t, 3e-6, *pricing.InputPrice, 1e-12)
	require.NotNil(t, pricing.CacheReadPrice)
}
// TestSynthesizePricingFromLiteLLM_ImageGenerationMode checks that when
// LiteLLM reports mode "image_generation" and the channel declared no mode
// of its own, pricing is synthesized in image mode: with only a per-image-token
// cost in the source, ImageOutputPrice is set and PerRequestPrice stays nil.
func TestSynthesizePricingFromLiteLLM_ImageGenerationMode(t *testing.T) {
	source := &LiteLLMModelPricing{
		Mode:                    "image_generation",
		OutputCostPerImageToken: 4e-5,
	}

	pricing := synthesizePricingFromLiteLLM(source, nil)

	require.NotNil(t, pricing)
	require.Equal(t, BillingModeImage, pricing.BillingMode)
	require.Nil(t, pricing.PerRequestPrice)
	require.NotNil(t, pricing.ImageOutputPrice)
}
func TestSynthesizePricingFromLiteLLM_RespectsExistingChannelMode(t *testing.T) {
// admin UI 选了 per_request 但没填价LiteLLM 数据按 per_request 合成,
// 即便 LiteLLM 标的是 chat 模式也尊重渠道选择。
lp := &LiteLLMModelPricing{
Mode: "chat",
InputCostPerToken: 5e-6,
OutputCostPerImage: 0.04,
}
existing := &ChannelModelPricing{BillingMode: BillingModePerRequest}
got := synthesizePricingFromLiteLLM(lp, existing)
require.NotNil(t, got)
require.Equal(t, BillingModePerRequest, got.BillingMode)
require.NotNil(t, got.PerRequestPrice)
require.InDelta(t, 0.04, *got.PerRequestPrice, 1e-12)
}
func TestFillGlobalPricingFallback_NilPricing(t *testing.T) {
pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{
"claude-opus-4-5": {Mode: "chat", InputCostPerToken: 5e-6},
})
svc := &ChannelService{pricingService: pricingSvc}
models := []SupportedModel{
{Name: "claude-opus-4-5", Platform: "anthropic"},
}
svc.fillGlobalPricingFallback(models)
require.NotNil(t, models[0].Pricing)
require.NotNil(t, models[0].Pricing.InputPrice)
require.InDelta(t, 5e-6, *models[0].Pricing.InputPrice, 1e-12)
}
func TestFillGlobalPricingFallback_EmptyPricingFillsFromLiteLLM(t *testing.T) {
// 核心场景admin UI 建了 pricing 条目image 模式)但没填价,应走 LiteLLM 兜底。
pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{
"gpt-image-1": {
Mode: "image_generation",
OutputCostPerImageToken: 4e-5,
},
})
svc := &ChannelService{pricingService: pricingSvc}
models := []SupportedModel{
{
Name: "gpt-image-1",
Platform: "openai",
Pricing: &ChannelModelPricing{
BillingMode: BillingModeImage,
Intervals: []PricingInterval{{TierLabel: "1K"}, {TierLabel: "2K"}},
},
},
}
svc.fillGlobalPricingFallback(models)
require.NotNil(t, models[0].Pricing)
require.Equal(t, BillingModeImage, models[0].Pricing.BillingMode)
require.NotNil(t, models[0].Pricing.ImageOutputPrice)
require.InDelta(t, 4e-5, *models[0].Pricing.ImageOutputPrice, 1e-12)
}
func TestFillGlobalPricingFallback_KeepsExistingPrice(t *testing.T) {
// 渠道已经填了价格的条目不应被回落覆盖。
pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{
"served-model": {Mode: "chat", InputCostPerToken: 1e-6},
})
svc := &ChannelService{pricingService: pricingSvc}
existing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(9e-9),
}
models := []SupportedModel{
{Name: "served-model", Platform: "anthropic", Pricing: existing},
}
svc.fillGlobalPricingFallback(models)
require.Same(t, existing, models[0].Pricing)
}
// newStubPricingServiceFromMap returns a PricingService whose pricingData is
// the given map and whose other fields are left at their zero values.
func newStubPricingServiceFromMap(data map[string]*LiteLLMModelPricing) *PricingService {
	svc := new(PricingService)
	svc.pricingData = data
	return svc
}