403 lines
13 KiB
Go
403 lines
13 KiB
Go
package taskcenter
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"testing"
|
||
"time"
|
||
|
||
"bindbox-game/internal/repository/mysql"
|
||
tcmodel "bindbox-game/internal/repository/mysql/task_center"
|
||
|
||
"gorm.io/datatypes"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
func ensureExtraTablesForServiceTest(t *testing.T, db *gorm.DB) {
|
||
if !db.Migrator().HasTable("orders") {
|
||
if err := db.Exec(`CREATE TABLE orders (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
user_id INTEGER NOT NULL,
|
||
status INTEGER NOT NULL DEFAULT 1,
|
||
source_type INTEGER NOT NULL DEFAULT 0,
|
||
total_amount INTEGER NOT NULL DEFAULT 0,
|
||
actual_amount INTEGER NOT NULL DEFAULT 0,
|
||
remark TEXT,
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||
deleted_at DATETIME
|
||
);`).Error; err != nil {
|
||
t.Fatalf("创建 orders 表失败: %v", err)
|
||
}
|
||
}
|
||
if !db.Migrator().HasTable("activity_draw_logs") {
|
||
if err := db.Exec(`CREATE TABLE activity_draw_logs (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
order_id INTEGER NOT NULL,
|
||
issue_id INTEGER NOT NULL
|
||
);`).Error; err != nil {
|
||
t.Fatalf("创建 activity_draw_logs 表失败: %v", err)
|
||
}
|
||
}
|
||
if !db.Migrator().HasTable("activity_issues") {
|
||
if err := db.Exec(`CREATE TABLE activity_issues (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
activity_id INTEGER NOT NULL
|
||
);`).Error; err != nil {
|
||
t.Fatalf("创建 activity_issues 表失败: %v", err)
|
||
}
|
||
}
|
||
if !db.Migrator().HasTable("activities") {
|
||
if err := db.Exec(`CREATE TABLE activities (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
price_draw INTEGER NOT NULL DEFAULT 0
|
||
);`).Error; err != nil {
|
||
t.Fatalf("创建 activities 表失败: %v", err)
|
||
}
|
||
}
|
||
if !db.Migrator().HasTable("user_invites") {
|
||
if err := db.Exec(`CREATE TABLE user_invites (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
inviter_id INTEGER NOT NULL,
|
||
invitee_id INTEGER NOT NULL,
|
||
accumulated_amount INTEGER NOT NULL DEFAULT 0,
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||
deleted_at DATETIME
|
||
);`).Error; err != nil {
|
||
t.Fatalf("创建 user_invites 表失败: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestGetUserProgress_TimeWindow_Integration(t *testing.T) {
|
||
repo, err := mysql.NewSQLiteRepoForTest()
|
||
if err != nil {
|
||
t.Fatalf("创建 repo 失败: %v", err)
|
||
}
|
||
db := repo.GetDbW()
|
||
|
||
initTestTables(t, db)
|
||
ensureExtraTablesForServiceTest(t, db)
|
||
|
||
svc := New(nil, repo, nil, nil, nil)
|
||
|
||
now := time.Now()
|
||
taskStart := now.Add(-200 * 24 * time.Hour)
|
||
taskEnd := now.Add(200 * 24 * time.Hour)
|
||
|
||
// 创建一个具有任务有效期的任务
|
||
task := &tcmodel.Task{
|
||
Name: "时效性测试任务",
|
||
Description: "测试各档位时效隔离",
|
||
Status: 1,
|
||
Visibility: 1,
|
||
StartTime: &taskStart,
|
||
EndTime: &taskEnd,
|
||
}
|
||
if err := db.Create(task).Error; err != nil {
|
||
t.Fatalf("创建任务失败: %v", err)
|
||
}
|
||
|
||
db.Exec("INSERT INTO activities (id, price_draw) VALUES (1, 100)")
|
||
db.Exec("INSERT INTO activity_issues (id, activity_id) VALUES (1, 1)")
|
||
|
||
windows := []string{WindowDaily, WindowWeekly, WindowMonthly, WindowActivityPeriod, WindowLifetime}
|
||
|
||
tierIDMap := make(map[string]int64)
|
||
for _, w := range windows {
|
||
tier := &tcmodel.TaskTier{
|
||
TaskID: task.ID,
|
||
Metric: MetricOrderCount,
|
||
Operator: OperatorGTE,
|
||
Threshold: 1,
|
||
Window: w,
|
||
ActivityID: 0,
|
||
}
|
||
if err := db.Create(tier).Error; err != nil {
|
||
t.Fatalf("创建档位失败: %v", err)
|
||
}
|
||
tierIDMap[w] = tier.ID
|
||
}
|
||
|
||
userID := int64(888)
|
||
|
||
// 插入三笔订单与邀请,处于不同时间段
|
||
o1Time := now.Format(time.DateTime)
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (101, ?, 2, 0, 100, ?)", userID, o1Time)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (101, 1)")
|
||
db.Exec("INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, 901, ?)", userID, o1Time)
|
||
|
||
o2Time := now.AddDate(0, -2, 0).Format(time.DateTime)
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (102, ?, 2, 0, 100, ?)", userID, o2Time)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (102, 1)")
|
||
db.Exec("INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, 902, ?)", userID, o2Time)
|
||
|
||
o3Time := now.AddDate(-1, 0, 0).Format(time.DateTime)
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (103, ?, 2, 0, 100, ?)", userID, o3Time)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (103, 1)")
|
||
db.Exec("INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, 903, ?)", userID, o3Time)
|
||
|
||
// 调用统计
|
||
progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
|
||
if err != nil {
|
||
t.Fatalf("获取进度失败: %v", err)
|
||
}
|
||
|
||
// 验证各 Tier 的统计数据符合预期
|
||
for w, tid := range tierIDMap {
|
||
tp, ok := progress.TierProgressMap[tid]
|
||
if !ok {
|
||
t.Errorf("缺少 %s 的进度", w)
|
||
continue
|
||
}
|
||
|
||
var expectedCount int64
|
||
switch w {
|
||
case WindowDaily, WindowWeekly, WindowMonthly:
|
||
expectedCount = 1
|
||
case WindowActivityPeriod:
|
||
expectedCount = 2 // O1, O2
|
||
case WindowLifetime:
|
||
expectedCount = 3 // O1, O2, O3
|
||
}
|
||
|
||
if tp.OrderCount != expectedCount {
|
||
t.Errorf("[%s] OrderCount 不符: Expected %d, Got %d", w, expectedCount, tp.OrderCount)
|
||
} else {
|
||
t.Logf("[%s] OrderCount 验证成功: %d", w, tp.OrderCount)
|
||
}
|
||
|
||
if tp.InviteCount != expectedCount {
|
||
t.Errorf("[%s] InviteCount 不符: Expected %d, Got %d", w, expectedCount, tp.InviteCount)
|
||
} else {
|
||
t.Logf("[%s] InviteCount 验证成功: %d", w, tp.InviteCount)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestUpsertTaskRewards_AllowsMultipleRewardsSameType(t *testing.T) {
|
||
repo, err := mysql.NewSQLiteRepoForTest()
|
||
if err != nil {
|
||
t.Fatalf("创建 repo 失败: %v", err)
|
||
}
|
||
db := repo.GetDbW()
|
||
initTestTables(t, db)
|
||
|
||
svc := New(nil, repo, nil, nil, nil)
|
||
|
||
task := &tcmodel.Task{Name: "奖励重入", Description: "测试奖励更新", Status: 1, Visibility: 1}
|
||
if err := db.Create(task).Error; err != nil {
|
||
t.Fatalf("创建任务失败: %v", err)
|
||
}
|
||
tier := &tcmodel.TaskTier{
|
||
TaskID: task.ID,
|
||
Metric: MetricOrderCount,
|
||
Operator: OperatorGTE,
|
||
Threshold: 1,
|
||
Window: WindowLifetime,
|
||
}
|
||
if err := db.Create(tier).Error; err != nil {
|
||
t.Fatalf("创建档位失败: %v", err)
|
||
}
|
||
|
||
initialRewards := []TaskRewardInput{
|
||
{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: datatypes.JSON([]byte(`{"coupon_id":1,"quantity":1}`)), Quantity: 1},
|
||
{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: datatypes.JSON([]byte(`{"coupon_id":2,"quantity":1}`)), Quantity: 2},
|
||
}
|
||
if err := svc.UpsertTaskRewards(context.Background(), task.ID, initialRewards, nil); err != nil {
|
||
t.Fatalf("首次保存奖励失败: %v", err)
|
||
}
|
||
|
||
var stored []tcmodel.TaskReward
|
||
if err := db.Where("task_id = ?", task.ID).Order("id asc").Find(&stored).Error; err != nil {
|
||
t.Fatalf("查询奖励失败: %v", err)
|
||
}
|
||
if len(stored) != 2 {
|
||
t.Fatalf("奖励数量不正确, 期望 2 实际 %d", len(stored))
|
||
}
|
||
|
||
updatePayload := datatypes.JSON([]byte(`{"coupon_id":99,"quantity":3}`))
|
||
secondPayload := datatypes.JSON([]byte(`{"coupon_id":200,"quantity":1}`))
|
||
updateInput := []TaskRewardInput{
|
||
{ID: stored[0].ID, TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: updatePayload, Quantity: 5},
|
||
{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: secondPayload, Quantity: 1},
|
||
}
|
||
if err := svc.UpsertTaskRewards(context.Background(), task.ID, updateInput, []int64{stored[1].ID}); err != nil {
|
||
t.Fatalf("更新奖励失败: %v", err)
|
||
}
|
||
|
||
var refreshed []tcmodel.TaskReward
|
||
if err := db.Where("task_id = ?", task.ID).Order("id asc").Find(&refreshed).Error; err != nil {
|
||
t.Fatalf("查询更新后奖励失败: %v", err)
|
||
}
|
||
if len(refreshed) != 2 {
|
||
t.Fatalf("更新后奖励数量不正确, 期望 2 实际 %d", len(refreshed))
|
||
}
|
||
if refreshed[0].ID != stored[0].ID {
|
||
t.Fatalf("原有奖励记录未被更新")
|
||
}
|
||
var pl map[string]int64
|
||
if err := json.Unmarshal(refreshed[0].RewardPayload, &pl); err != nil {
|
||
t.Fatalf("解析奖励 payload 失败: %v", err)
|
||
}
|
||
if pl["coupon_id"] != 99 {
|
||
t.Errorf("奖励 payload 未更新, 期望 99 实际 %d", pl["coupon_id"])
|
||
}
|
||
if refreshed[0].Quantity != 5 {
|
||
t.Errorf("奖励数量未更新, 期望 5 实际 %d", refreshed[0].Quantity)
|
||
}
|
||
for _, r := range refreshed {
|
||
if r.ID == stored[1].ID {
|
||
t.Fatalf("待删除的奖励仍存在, id=%d", r.ID)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestGetUserProgress_UsesEffectiveAmount(t *testing.T) {
|
||
repo, err := mysql.NewSQLiteRepoForTest()
|
||
if err != nil {
|
||
t.Fatalf("创建 repo 失败: %v", err)
|
||
}
|
||
db := repo.GetDbW()
|
||
initTestTables(t, db)
|
||
ensureExtraTablesForServiceTest(t, db)
|
||
|
||
svc := New(nil, repo, nil, nil, nil)
|
||
|
||
task := &tcmodel.Task{Name: "真实消费口径", Status: 1, Visibility: 1}
|
||
if err := db.Create(task).Error; err != nil {
|
||
t.Fatalf("创建任务失败: %v", err)
|
||
}
|
||
|
||
tier := &tcmodel.TaskTier{
|
||
TaskID: task.ID,
|
||
Metric: MetricOrderAmount,
|
||
Operator: OperatorGTE,
|
||
Threshold: 1,
|
||
Window: WindowLifetime,
|
||
ActivityID: 201,
|
||
}
|
||
if err := db.Create(tier).Error; err != nil {
|
||
t.Fatalf("创建档位失败: %v", err)
|
||
}
|
||
|
||
secondaryTier := &tcmodel.TaskTier{
|
||
TaskID: task.ID,
|
||
Metric: MetricOrderAmount,
|
||
Operator: OperatorGTE,
|
||
Threshold: 1,
|
||
Window: WindowLifetime,
|
||
ActivityID: 202,
|
||
}
|
||
if err := db.Create(secondaryTier).Error; err != nil {
|
||
t.Fatalf("创建第二个档位失败: %v", err)
|
||
}
|
||
|
||
db.Exec("INSERT INTO activities (id, price_draw) VALUES (201, 1000)")
|
||
db.Exec("INSERT INTO activities (id, price_draw) VALUES (202, 0)")
|
||
db.Exec("INSERT INTO activity_issues (id, activity_id) VALUES (301, 201)")
|
||
db.Exec("INSERT INTO activity_issues (id, activity_id) VALUES (302, 202)")
|
||
|
||
userID := int64(6001)
|
||
now := time.Now()
|
||
inside := now.Format(time.DateTime)
|
||
|
||
// 次卡订单:total_amount=0,但 price_draw>0, draw_count=2
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (401, ?, 2, 0, 0, ?)", userID, inside)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (401, 301)")
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (401, 301)")
|
||
|
||
// 现金订单:price_draw=0,需回退 total_amount
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (402, ?, 2, 0, 1500, ?)", userID, inside)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (402, 302)")
|
||
|
||
progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
|
||
if err != nil {
|
||
t.Fatalf("获取进度失败: %v", err)
|
||
}
|
||
|
||
if progress.OrderAmount != 3500 {
|
||
t.Fatalf("订单金额统计错误,期望 3500 实际 %d", progress.OrderAmount)
|
||
}
|
||
if progress.OrderCount != 2 {
|
||
t.Fatalf("订单数量统计错误,期望 2 实际 %d", progress.OrderCount)
|
||
}
|
||
tierProgress, ok := progress.TierProgressMap[tier.ID]
|
||
if !ok {
|
||
t.Fatalf("未找到档位进度")
|
||
}
|
||
if tierProgress.OrderAmount != 2000 {
|
||
t.Fatalf("档位金额错误,期望 2000 实际 %d", tierProgress.OrderAmount)
|
||
}
|
||
if tierProgress.OrderCount != 1 {
|
||
t.Fatalf("档位订单数错误,期望 1 实际 %d", tierProgress.OrderCount)
|
||
}
|
||
}
|
||
|
||
func TestTimeWindow_ActivityPeriod(t *testing.T) {
|
||
repo, err := mysql.NewSQLiteRepoForTest()
|
||
if err != nil {
|
||
t.Fatalf("创建 repo 失败: %v", err)
|
||
}
|
||
db := repo.GetDbW()
|
||
initTestTables(t, db)
|
||
ensureExtraTablesForServiceTest(t, db)
|
||
|
||
svc := New(nil, repo, nil, nil, nil)
|
||
|
||
start := time.Now().AddDate(0, -1, 0)
|
||
end := start.AddDate(0, 0, 10)
|
||
task := &tcmodel.Task{
|
||
Name: "任务窗口期",
|
||
Status: 1,
|
||
Visibility: 1,
|
||
StartTime: &start,
|
||
EndTime: &end,
|
||
}
|
||
if err := db.Create(task).Error; err != nil {
|
||
t.Fatalf("创建任务失败: %v", err)
|
||
}
|
||
|
||
tier := &tcmodel.TaskTier{
|
||
TaskID: task.ID,
|
||
Metric: MetricOrderCount,
|
||
Operator: OperatorGTE,
|
||
Threshold: 1,
|
||
Window: WindowActivityPeriod,
|
||
ActivityID: 501,
|
||
}
|
||
if err := db.Create(tier).Error; err != nil {
|
||
t.Fatalf("创建档位失败: %v", err)
|
||
}
|
||
|
||
db.Exec("INSERT INTO activities (id, price_draw) VALUES (501, 500)")
|
||
db.Exec("INSERT INTO activity_issues (id, activity_id) VALUES (601, 501)")
|
||
|
||
userID := int64(7007)
|
||
inside := start.Add(24 * time.Hour).Format(time.DateTime)
|
||
outside := end.Add(24 * time.Hour).Format(time.DateTime)
|
||
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (701, ?, 2, 0, 0, ?)", userID, inside)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (701, 601)")
|
||
|
||
db.Exec("INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (702, ?, 2, 0, 0, ?)", userID, outside)
|
||
db.Exec("INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (702, 601)")
|
||
|
||
progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
|
||
if err != nil {
|
||
t.Fatalf("获取进度失败: %v", err)
|
||
}
|
||
|
||
tierProgress, ok := progress.TierProgressMap[tier.ID]
|
||
if !ok {
|
||
t.Fatalf("未找到活动有效期档位进度")
|
||
}
|
||
if tierProgress.OrderCount != 1 {
|
||
t.Fatalf("活动有效期窗口统计错误,期望 1 实际 %d", tierProgress.OrderCount)
|
||
}
|
||
if progress.OrderCount != 2 {
|
||
t.Fatalf("总体订单统计错误,期望 2 实际 %d", progress.OrderCount)
|
||
}
|
||
}
|