2026-02-27 00:08:02 +08:00

403 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package taskcenter
import (
"context"
"encoding/json"
"testing"
"time"
"bindbox-game/internal/repository/mysql"
tcmodel "bindbox-game/internal/repository/mysql/task_center"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// ensureExtraTablesForServiceTest creates, if they do not exist yet, the
// auxiliary tables the service tests in this file insert fixtures into:
// orders, activity_draw_logs, activity_issues, activities and user_invites.
// Tables are checked and created in that order; any DDL failure aborts the
// test via t.Fatalf.
func ensureExtraTablesForServiceTest(t *testing.T, db *gorm.DB) {
	extraTables := []struct {
		name string
		ddl  string
	}{
		{name: "orders", ddl: `CREATE TABLE orders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
status INTEGER NOT NULL DEFAULT 1,
source_type INTEGER NOT NULL DEFAULT 0,
total_amount INTEGER NOT NULL DEFAULT 0,
actual_amount INTEGER NOT NULL DEFAULT 0,
remark TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
deleted_at DATETIME
);`},
		{name: "activity_draw_logs", ddl: `CREATE TABLE activity_draw_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
order_id INTEGER NOT NULL,
issue_id INTEGER NOT NULL
);`},
		{name: "activity_issues", ddl: `CREATE TABLE activity_issues (
id INTEGER PRIMARY KEY AUTOINCREMENT,
activity_id INTEGER NOT NULL
);`},
		{name: "activities", ddl: `CREATE TABLE activities (
id INTEGER PRIMARY KEY AUTOINCREMENT,
price_draw INTEGER NOT NULL DEFAULT 0
);`},
		{name: "user_invites", ddl: `CREATE TABLE user_invites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
inviter_id INTEGER NOT NULL,
invitee_id INTEGER NOT NULL,
accumulated_amount INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
deleted_at DATETIME
);`},
	}

	for _, tbl := range extraTables {
		// Leave tables that already exist (e.g. created by other setup) alone.
		if db.Migrator().HasTable(tbl.name) {
			continue
		}
		if err := db.Exec(tbl.ddl).Error; err != nil {
			t.Fatalf("创建 %s 表失败: %v", tbl.name, err)
		}
	}
}
// TestGetUserProgress_TimeWindow_Integration verifies that every tier window
// (daily, weekly, monthly, activity period, lifetime) counts only the orders
// and invites that fall inside it.
//
// Fixture: one order (with a draw log) and one invite for the same user at
// now, now-2 months and now-1 year. The task runs from now-200d to now+200d,
// so the test expects 1 hit for daily/weekly/monthly, 2 for the activity
// period and 3 for lifetime, for both OrderCount and InviteCount.
func TestGetUserProgress_TimeWindow_Integration(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	ensureExtraTablesForServiceTest(t, db)

	// mustExec runs a fixture statement and fails fast. Ignoring the error
	// would let a broken fixture surface later as a misleading count mismatch.
	mustExec := func(query string, args ...any) {
		t.Helper()
		if err := db.Exec(query, args...).Error; err != nil {
			t.Fatalf("执行 SQL 失败 (%s): %v", query, err)
		}
	}

	svc := New(nil, repo, nil, nil, nil)
	now := time.Now()
	taskStart := now.Add(-200 * 24 * time.Hour)
	taskEnd := now.Add(200 * 24 * time.Hour)

	// Task with an explicit validity period; the activity-period expectation
	// below relies on it excluding the order from one year ago.
	task := &tcmodel.Task{
		Name:        "时效性测试任务",
		Description: "测试各档位时效隔离",
		Status:      1,
		Visibility:  1,
		StartTime:   &taskStart,
		EndTime:     &taskEnd,
	}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	mustExec("INSERT INTO activities (id, price_draw) VALUES (1, 100)")
	mustExec("INSERT INTO activity_issues (id, activity_id) VALUES (1, 1)")

	// One order-count tier per window, all with the same threshold.
	windows := []string{WindowDaily, WindowWeekly, WindowMonthly, WindowActivityPeriod, WindowLifetime}
	tierIDMap := make(map[string]int64, len(windows))
	for _, w := range windows {
		tier := &tcmodel.TaskTier{
			TaskID:     task.ID,
			Metric:     MetricOrderCount,
			Operator:   OperatorGTE,
			Threshold:  1,
			Window:     w,
			ActivityID: 0,
		}
		if err := db.Create(tier).Error; err != nil {
			t.Fatalf("创建档位失败: %v", err)
		}
		tierIDMap[w] = tier.ID
	}

	userID := int64(888)
	// Three orders (each with a draw log) and three invites at different times.
	o1Time := now.Format(time.DateTime)
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (101, ?, 2, 0, 100, ?)", userID, o1Time)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (101, 1)")
	mustExec("INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, 901, ?)", userID, o1Time)

	o2Time := now.AddDate(0, -2, 0).Format(time.DateTime)
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (102, ?, 2, 0, 100, ?)", userID, o2Time)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (102, 1)")
	mustExec("INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, 902, ?)", userID, o2Time)

	o3Time := now.AddDate(-1, 0, 0).Format(time.DateTime)
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (103, ?, 2, 0, 100, ?)", userID, o3Time)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (103, 1)")
	mustExec("INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, 903, ?)", userID, o3Time)

	progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
	if err != nil {
		t.Fatalf("获取进度失败: %v", err)
	}

	// Walk the slice rather than the map so failures and logs appear in a
	// stable, readable order.
	for _, w := range windows {
		tp, ok := progress.TierProgressMap[tierIDMap[w]]
		if !ok {
			t.Errorf("缺少 %s 的进度", w)
			continue
		}
		var expectedCount int64
		switch w {
		case WindowDaily, WindowWeekly, WindowMonthly:
			expectedCount = 1
		case WindowActivityPeriod:
			expectedCount = 2 // O1, O2
		case WindowLifetime:
			expectedCount = 3 // O1, O2, O3
		}
		if tp.OrderCount != expectedCount {
			t.Errorf("[%s] OrderCount 不符: Expected %d, Got %d", w, expectedCount, tp.OrderCount)
		} else {
			t.Logf("[%s] OrderCount 验证成功: %d", w, tp.OrderCount)
		}
		if tp.InviteCount != expectedCount {
			t.Errorf("[%s] InviteCount 不符: Expected %d, Got %d", w, expectedCount, tp.InviteCount)
		} else {
			t.Logf("[%s] InviteCount 验证成功: %d", w, tp.InviteCount)
		}
	}
}
// TestUpsertTaskRewards_AllowsMultipleRewardsSameType checks that a tier can
// hold several coupon rewards at once. A second UpsertTaskRewards call must:
// update an existing reward in place (same ID, new payload and quantity),
// insert a new reward of the same type, and delete the reward whose ID is
// passed in the deletion list.
func TestUpsertTaskRewards_AllowsMultipleRewardsSameType(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	svc := New(nil, repo, nil, nil, nil)

	task := &tcmodel.Task{Name: "奖励重入", Description: "测试奖励更新", Status: 1, Visibility: 1}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	tier := &tcmodel.TaskTier{
		TaskID:    task.ID,
		Metric:    MetricOrderCount,
		Operator:  OperatorGTE,
		Threshold: 1,
		Window:    WindowLifetime,
	}
	if err := db.Create(tier).Error; err != nil {
		t.Fatalf("创建档位失败: %v", err)
	}

	// Round 1: two coupon rewards on the same tier.
	initialRewards := []TaskRewardInput{
		{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: datatypes.JSON([]byte(`{"coupon_id":1,"quantity":1}`)), Quantity: 1},
		{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: datatypes.JSON([]byte(`{"coupon_id":2,"quantity":1}`)), Quantity: 2},
	}
	if err := svc.UpsertTaskRewards(context.Background(), task.ID, initialRewards, nil); err != nil {
		t.Fatalf("首次保存奖励失败: %v", err)
	}
	var stored []tcmodel.TaskReward
	if err := db.Where("task_id = ?", task.ID).Order("id asc").Find(&stored).Error; err != nil {
		t.Fatalf("查询奖励失败: %v", err)
	}
	if len(stored) != 2 {
		t.Fatalf("奖励数量不正确, 期望 2 实际 %d", len(stored))
	}

	// Round 2: update the first reward, add a new one, delete the second.
	updatePayload := datatypes.JSON([]byte(`{"coupon_id":99,"quantity":3}`))
	secondPayload := datatypes.JSON([]byte(`{"coupon_id":200,"quantity":1}`))
	updateInput := []TaskRewardInput{
		{ID: stored[0].ID, TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: updatePayload, Quantity: 5},
		{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: secondPayload, Quantity: 1},
	}
	if err := svc.UpsertTaskRewards(context.Background(), task.ID, updateInput, []int64{stored[1].ID}); err != nil {
		t.Fatalf("更新奖励失败: %v", err)
	}
	var refreshed []tcmodel.TaskReward
	if err := db.Where("task_id = ?", task.ID).Order("id asc").Find(&refreshed).Error; err != nil {
		t.Fatalf("查询更新后奖励失败: %v", err)
	}
	if len(refreshed) != 2 {
		t.Fatalf("更新后奖励数量不正确, 期望 2 实际 %d", len(refreshed))
	}

	// refreshed is ordered by id asc and stored[0] had the smallest id, so the
	// updated reward must come first.
	if refreshed[0].ID != stored[0].ID {
		t.Fatalf("原有奖励记录未被更新")
	}
	var pl map[string]int64
	if err := json.Unmarshal(refreshed[0].RewardPayload, &pl); err != nil {
		t.Fatalf("解析奖励 payload 失败: %v", err)
	}
	if pl["coupon_id"] != 99 {
		t.Errorf("奖励 payload 未更新, 期望 99 实际 %d", pl["coupon_id"])
	}
	if refreshed[0].Quantity != 5 {
		t.Errorf("奖励数量未更新, 期望 5 实际 %d", refreshed[0].Quantity)
	}

	// The remaining row must be the newly added second coupon reward, which
	// is the point of this test (multiple rewards of the same type).
	var added map[string]int64
	if err := json.Unmarshal(refreshed[1].RewardPayload, &added); err != nil {
		t.Fatalf("解析新增奖励 payload 失败: %v", err)
	}
	if added["coupon_id"] != 200 {
		t.Errorf("新增奖励 payload 不符, 期望 200 实际 %d", added["coupon_id"])
	}
	if refreshed[1].Quantity != 1 {
		t.Errorf("新增奖励数量不符, 期望 1 实际 %d", refreshed[1].Quantity)
	}

	for _, r := range refreshed {
		if r.ID == stored[1].ID {
			t.Fatalf("待删除的奖励仍存在, id=%d", r.ID)
		}
	}
}
// TestGetUserProgress_UsesEffectiveAmount checks that the order-amount metric
// uses the "effective" amount. For draw orders the test expects
// price_draw × number of draw logs, falling back to total_amount when
// price_draw is 0.
//
// Fixture:
//   - order 401: total_amount 0, two draws on activity 201 (price_draw 1000)
//   - order 402: total_amount 1500, one draw on activity 202 (price_draw 0)
//
// Expected: 3500 / 2 orders overall, and 2000 / 1 order for the tier bound to
// activity 201.
func TestGetUserProgress_UsesEffectiveAmount(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	ensureExtraTablesForServiceTest(t, db)

	// mustExec runs a fixture statement and fails fast. Ignoring the error
	// would let a broken fixture surface later as a wrong amount.
	mustExec := func(query string, args ...any) {
		t.Helper()
		if err := db.Exec(query, args...).Error; err != nil {
			t.Fatalf("执行 SQL 失败 (%s): %v", query, err)
		}
	}

	svc := New(nil, repo, nil, nil, nil)
	task := &tcmodel.Task{Name: "真实消费口径", Status: 1, Visibility: 1}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	tier := &tcmodel.TaskTier{
		TaskID:     task.ID,
		Metric:     MetricOrderAmount,
		Operator:   OperatorGTE,
		Threshold:  1,
		Window:     WindowLifetime,
		ActivityID: 201,
	}
	if err := db.Create(tier).Error; err != nil {
		t.Fatalf("创建档位失败: %v", err)
	}
	secondaryTier := &tcmodel.TaskTier{
		TaskID:     task.ID,
		Metric:     MetricOrderAmount,
		Operator:   OperatorGTE,
		Threshold:  1,
		Window:     WindowLifetime,
		ActivityID: 202,
	}
	if err := db.Create(secondaryTier).Error; err != nil {
		t.Fatalf("创建第二个档位失败: %v", err)
	}

	mustExec("INSERT INTO activities (id, price_draw) VALUES (201, 1000)")
	mustExec("INSERT INTO activities (id, price_draw) VALUES (202, 0)")
	mustExec("INSERT INTO activity_issues (id, activity_id) VALUES (301, 201)")
	mustExec("INSERT INTO activity_issues (id, activity_id) VALUES (302, 202)")

	userID := int64(6001)
	now := time.Now()
	inside := now.Format(time.DateTime)
	// Pass-card order: total_amount=0 but price_draw>0, draw_count=2.
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (401, ?, 2, 0, 0, ?)", userID, inside)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (401, 301)")
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (401, 301)")
	// Cash order: price_draw=0, so the amount must fall back to total_amount.
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (402, ?, 2, 0, 1500, ?)", userID, inside)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (402, 302)")

	progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
	if err != nil {
		t.Fatalf("获取进度失败: %v", err)
	}
	if progress.OrderAmount != 3500 {
		t.Fatalf("订单金额统计错误,期望 3500 实际 %d", progress.OrderAmount)
	}
	if progress.OrderCount != 2 {
		t.Fatalf("订单数量统计错误,期望 2 实际 %d", progress.OrderCount)
	}
	tierProgress, ok := progress.TierProgressMap[tier.ID]
	if !ok {
		t.Fatalf("未找到档位进度")
	}
	if tierProgress.OrderAmount != 2000 {
		t.Fatalf("档位金额错误,期望 2000 实际 %d", tierProgress.OrderAmount)
	}
	if tierProgress.OrderCount != 1 {
		t.Fatalf("档位订单数错误,期望 1 实际 %d", tierProgress.OrderCount)
	}
}
// TestTimeWindow_ActivityPeriod checks that a WindowActivityPeriod tier only
// counts orders created within the task's StartTime/EndTime range, while the
// overall progress still counts every order.
//
// Fixture: the task starts one month ago and lasts 10 days. Order 701 is
// created one day after the start (inside) and order 702 one day after the
// end (outside). Expected: tier count 1, overall count 2.
func TestTimeWindow_ActivityPeriod(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	ensureExtraTablesForServiceTest(t, db)

	// mustExec runs a fixture statement and fails fast. Ignoring the error
	// would let a broken fixture surface later as a wrong count.
	mustExec := func(query string, args ...any) {
		t.Helper()
		if err := db.Exec(query, args...).Error; err != nil {
			t.Fatalf("执行 SQL 失败 (%s): %v", query, err)
		}
	}

	svc := New(nil, repo, nil, nil, nil)
	start := time.Now().AddDate(0, -1, 0)
	end := start.AddDate(0, 0, 10)
	task := &tcmodel.Task{
		Name:       "任务窗口期",
		Status:     1,
		Visibility: 1,
		StartTime:  &start,
		EndTime:    &end,
	}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	tier := &tcmodel.TaskTier{
		TaskID:     task.ID,
		Metric:     MetricOrderCount,
		Operator:   OperatorGTE,
		Threshold:  1,
		Window:     WindowActivityPeriod,
		ActivityID: 501,
	}
	if err := db.Create(tier).Error; err != nil {
		t.Fatalf("创建档位失败: %v", err)
	}
	mustExec("INSERT INTO activities (id, price_draw) VALUES (501, 500)")
	mustExec("INSERT INTO activity_issues (id, activity_id) VALUES (601, 501)")

	userID := int64(7007)
	inside := start.Add(24 * time.Hour).Format(time.DateTime)
	outside := end.Add(24 * time.Hour).Format(time.DateTime)
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (701, ?, 2, 0, 0, ?)", userID, inside)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (701, 601)")
	mustExec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (702, ?, 2, 0, 0, ?)", userID, outside)
	mustExec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (702, 601)")

	progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
	if err != nil {
		t.Fatalf("获取进度失败: %v", err)
	}
	tierProgress, ok := progress.TierProgressMap[tier.ID]
	if !ok {
		t.Fatalf("未找到活动有效期档位进度")
	}
	if tierProgress.OrderCount != 1 {
		t.Fatalf("活动有效期窗口统计错误,期望 1 实际 %d", tierProgress.OrderCount)
	}
	if progress.OrderCount != 2 {
		t.Fatalf("总体订单统计错误,期望 2 实际 %d", progress.OrderCount)
	}
}