From df2b02e61c84d87bf9725135ec2b4830c7ce3092 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 20 May 2026 16:53:23 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3=E5=88=86=E7=BB=84?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7=E5=8F=AF=E7=94=A8=E8=AE=A1=E6=95=B0=E5=8F=A3?= =?UTF-8?q?=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/repository/group_repo.go | 59 ++++++++++++------- .../repository/group_repo_integration_test.go | 54 ++++++++++++++--- 2 files changed, 86 insertions(+), 27 deletions(-) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 9b6377bc..66b2316a 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -94,9 +94,13 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group if err != nil { return nil, err } - total, active, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = total - out.ActiveAccountCount = active + counts, err := r.loadAccountCounts(ctx, []int64{out.ID}) + if err == nil { + c := counts[out.ID] + out.AccountCount = c.Total + out.ActiveAccountCount = c.Active + out.RateLimitedAccountCount = c.RateLimited + } return out, nil } @@ -538,15 +542,12 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) { var rateLimited int64 err = scanSingleRow(ctx, r.sql, - `SELECT COUNT(*), - COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true), - COUNT(*) FILTER (WHERE a.status = 'active' AND ( - a.rate_limit_reset_at > NOW() OR - a.overload_until > NOW() OR - a.temp_unschedulable_until > NOW() - )) + fmt.Sprintf(`SELECT + COUNT(*) FILTER (WHERE a.deleted_at IS NULL), + COUNT(*) FILTER (WHERE %s), + COUNT(*) FILTER (WHERE %s) FROM account_groups ag JOIN accounts a ON a.id = ag.account_id - WHERE ag.group_id = $1`, + WHERE ag.group_id = $1`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL), []any{groupID}, &total, &active, &rateLimited) return } @@ -680,6 +681,28 @@ type groupAccountCounts struct { RateLimited int64 } +const ( + // 分组页的"可用"账号数必须与账号仓储的 ListSchedulableByGroupID 过滤口径一致。 + groupAccountAvailableSQL = `a.deleted_at IS NULL + AND a.status = 'active' + AND a.schedulable = true + AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE) + AND (a.rate_limit_reset_at IS NULL OR a.rate_limit_reset_at <= NOW()) + AND (a.overload_until IS NULL OR a.overload_until <= NOW()) + AND (a.temp_unschedulable_until IS NULL OR a.temp_unschedulable_until <= NOW())` + + // 这里沿用历史字段名 RateLimitedAccountCount,但统计的是会让账号暂时退出调度的时间窗口。 + groupAccountTemporarilyLimitedSQL = `a.deleted_at IS NULL + AND a.status = 'active' + AND a.schedulable = true + AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE) + AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )` +) + func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) { counts = make(map[int64]groupAccountCounts, len(groupIDs)) if len(groupIDs) == 0 { @@ -688,18 +711,14 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 rows, err := r.sql.QueryContext( ctx, - `SELECT ag.group_id, - COUNT(*) AS total, - COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active, - COUNT(*) FILTER (WHERE a.status = 'active' AND ( - a.rate_limit_reset_at > NOW() OR - a.overload_until > NOW() OR - a.temp_unschedulable_until > NOW() - )) AS rate_limited + fmt.Sprintf(`SELECT ag.group_id, + COUNT(*) FILTER (WHERE a.deleted_at IS NULL) AS total, + COUNT(*) FILTER (WHERE %s) AS active, + COUNT(*) FILTER (WHERE %s) AS rate_limited FROM account_groups ag JOIN accounts a ON a.id = ag.account_id WHERE ag.group_id = ANY($1) - GROUP BY ag.group_id`, + GROUP BY ag.group_id`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL), pq.Array(groupIDs), ) if err != nil { diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index c98d8861..68183b2b 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -714,9 +714,9 @@ func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount") } -// TestListWithFilters_RateLimitedAccountCount 验证 RateLimitedAccountCount 正确统计临时受限账号。 -// 受限账号(rate_limit_reset_at 尚未过期)仍然计入 ActiveAccountCount, -// 同时额外出现在 RateLimitedAccountCount 中。 +// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。 +// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池, +// 因此 ActiveAccountCount 必须与真实调度查询口径一致。 func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() { g := &service.Group{ Name: "g-rate-limited", @@ -740,6 +740,24 @@ func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() { []any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth}, &rateLimitedID)) + var overloadedID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth}, + &overloadedID)) + + var tempUnschedulableID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth}, + &tempUnschedulableID)) + + var expiredID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id", + []any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth}, + &expiredID)) + _, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", normalID, g.ID, 1) @@ -748,6 +766,18 @@ func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() { "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", rateLimitedID, g.ID, 2) s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + overloadedID, g.ID, 3) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + tempUnschedulableID, g.ID, 4) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + expiredID, g.ID, 5) + s.Require().NoError(err) isExclusive := false groups, _, err := s.repo.ListWithFilters(s.ctx, @@ -763,10 +793,20 @@ func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() { } } s.Require().NotNil(found, "created group must appear in ListWithFilters result") - s.Assert().Equal(int64(2), found.AccountCount, "AccountCount must be 2") - // rate-limited account is still active+schedulable, so it counts toward active - s.Assert().Equal(int64(2), found.ActiveAccountCount, "rate-limited account still counts as active") - s.Assert().Equal(int64(1), found.RateLimitedAccountCount, "RateLimitedAccountCount must be 1") + s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts") + s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts") + s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts") + + total, active, err := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount") + s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount") + + detail, err := s.repo.GetByID(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters") + s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters") + s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters") } // --- DeleteAccountGroupsByGroupID ---