sub2api/backend/internal/repository/group_repo_integration_test.go

938 lines
33 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build integration
package repository
import (
"context"
"database/sql"
"errors"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
tx *dbent.Tx
repo *groupRepository
}
type forbidSQLExecutor struct {
called bool
}
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
s.called = true
return nil, errors.New("unexpected sql exec")
}
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
s.called = true
return nil, errors.New("unexpected sql query")
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.tx = tx
s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
}
func TestGroupRepoSuite(t *testing.T) {
suite.Run(t, new(GroupRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &service.Group{
Name: "test-create",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
err := s.repo.Create(s.ctx, group)
s.Require().NoError(err, "Create")
s.Require().NotZero(group.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
group := &service.Group{
Name: "lite-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
spy := &forbidSQLExecutor{}
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
got, err := repo.GetByIDLite(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.ID, got.ID)
s.Require().False(spy.called, "expected no direct sql executor usage")
}
func (s *GroupRepoSuite) TestUpdate() {
group := &service.Group{
Name: "original",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_PreservesMessagesDispatchModelConfig() {
group := &service.Group{
Name: "openai-dispatch",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
}
s.Require().NoError(s.repo.Create(s.ctx, group))
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
}
func (s *GroupRepoSuite) TestDelete() {
group := &service.Group{
Name: "to-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "expected error after delete")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(groups, len(baseGroups)+2)
s.Require().Equal(basePage.Total+2, page.Total)
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
baseGroups, _, err := s.repo.ListWithFilters(
s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI,
"",
"",
nil,
)
s.Require().NoError(err, "ListWithFilters base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform
for _, g := range groups {
s.Require().Equal(service.PlatformOpenAI, g.Platform)
}
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: true,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
g1 := &service.Group{
Name: "sort-g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g2 := &service.Group{
Name: "sort-g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g3 := &service.Group{
Name: "sort-g3",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
s.Require().NoError(s.repo.Create(s.ctx, g3))
err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
{ID: g1.ID, SortOrder: 30},
{ID: g2.ID, SortOrder: 10},
{ID: g3.ID, SortOrder: 20},
{ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准
})
s.Require().NoError(err)
got1, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
got2, err := s.repo.GetByID(s.ctx, g2.ID)
s.Require().NoError(err)
got3, err := s.repo.GetByID(s.ctx, g3.ID)
s.Require().NoError(err)
s.Require().Equal(30, got1.SortOrder)
s.Require().Equal(15, got2.SortOrder)
s.Require().Equal(20, got3.SortOrder)
}
func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
g1 := &service.Group{
Name: "sort-no-partial",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
before, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
beforeSort := before.SortOrder
err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
{ID: g1.ID, SortOrder: 99},
{ID: 99999999, SortOrder: 1},
})
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
after, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
s.Require().Equal(beforeSort, after.SortOrder)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g2 := &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: true,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
var accountID int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
&accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
s.Require().NoError(err)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
}
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
baseGroups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "active1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "inactive1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(groups, len(baseGroups)+1)
// Verify our test group is in the results
var found bool
for _, g := range groups {
if g.Name == "active1" {
found = true
break
}
}
s.Require().True(found, "active1 group should be in results")
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g3",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
// 1 default anthropic group + 1 test active anthropic group = 2 total
s.Require().Len(groups, 2)
// Verify our test group is in the results
var found bool
for _, g := range groups {
if g.Name == "g1" {
found = true
break
}
}
s.Require().True(found, "g1 group should be in results")
}
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "existing-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
s.Require().True(exists)
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := &service.Group{
Name: "g-count",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
var a1 int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
&a1,
))
var a2 int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
&a2,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := &service.Group{
Name: "g-empty",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// TestListWithFilters_ActiveAccountCount_LessThanTotal 验证 ActiveAccountCount 正确区分可用与不可用账号。
// 当分组内存在 disabled 或 schedulable=false 的账号时ActiveAccountCount 必须小于 AccountCount
// 且与 GetAccountCount 返回的 active 值一致。
func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() {
g := &service.Group{
Name: "g-mixed-status",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
insertAccount := func(name, status string, schedulable bool) int64 {
var id int64
s.Require().NoError(scanSingleRow(
s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, status, schedulable) VALUES ($1, $2, $3, $4, $5) RETURNING id",
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status, schedulable},
&id,
))
return id
}
link := func(accountID int64, priority int) {
_, err := s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
accountID, g.ID, priority)
s.Require().NoError(err)
}
// account 1: active + schedulable → counts toward both total and active
link(insertAccount("acc-active-sched", service.StatusActive, true), 1)
// account 2: disabled → counts toward total only
link(insertAccount("acc-disabled", service.StatusDisabled, true), 2)
// account 3: active + not schedulable → counts toward total only
link(insertAccount("acc-unschedulable", service.StatusActive, false), 3)
// --- ListWithFilters path ---
isExclusive := false
groups, _, err := s.repo.ListWithFilters(s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 100},
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err)
var found *service.Group
for i := range groups {
if groups[i].ID == g.ID {
found = &groups[i]
break
}
}
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
s.Assert().Equal(int64(3), found.AccountCount, "AccountCount must count all 3 accounts")
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must count only the active+schedulable account")
// --- GetAccountCount must return identical values ---
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err)
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
}
// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。
// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池,
// 因此 ActiveAccountCount 必须与真实调度查询口径一致。
func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() {
g := &service.Group{
Name: "g-rate-limited",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
var normalID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"acc-normal", service.PlatformAnthropic, service.AccountTypeOAuth},
&normalID))
var rateLimitedID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, rate_limit_reset_at) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
[]any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth},
&rateLimitedID))
var overloadedID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
[]any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth},
&overloadedID))
var tempUnschedulableID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
[]any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth},
&tempUnschedulableID))
var expiredID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id",
[]any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth},
&expiredID))
_, err := s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
normalID, g.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
rateLimitedID, g.ID, 2)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
overloadedID, g.ID, 3)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
tempUnschedulableID, g.ID, 4)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
expiredID, g.ID, 5)
s.Require().NoError(err)
isExclusive := false
groups, _, err := s.repo.ListWithFilters(s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 100},
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err)
var found *service.Group
for i := range groups {
if groups[i].ID == g.ID {
found = &groups[i]
break
}
}
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts")
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts")
s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts")
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err)
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
detail, err := s.repo.GetByID(s.ctx, g.ID)
s.Require().NoError(err)
s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters")
s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters")
s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters")
}
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := &service.Group{
Name: "g-del",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
var accountID int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
&accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
s.Require().NoError(err)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := &service.Group{
Name: "g-multi",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
insertAccount := func(name string) int64 {
var id int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
&id,
))
return id
}
a1 := insertAccount("a1")
a2 := insertAccount("a2")
a3 := insertAccount("a3")
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, g.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, g.ID, 2)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a3, g.ID, 3)
s.Require().NoError(err)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}