package taskcenter

import (
	"context"
	"encoding/json"
	"testing"
	"time"

	"bindbox-game/internal/repository/mysql"
	tcmodel "bindbox-game/internal/repository/mysql/task_center"
	"gorm.io/datatypes"
	"gorm.io/gorm"
)

// ensureExtraTablesForServiceTest creates the fixture tables that the tests in
// this file populate before calling GetUserProgress. Each table is created only
// if it does not already exist, so it is safe to call after initTestTables.
func ensureExtraTablesForServiceTest(t *testing.T, db *gorm.DB) {
	t.Helper()

	tables := []struct {
		name string
		ddl  string
	}{
		{"orders", `CREATE TABLE orders (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NOT NULL,
			status INTEGER NOT NULL DEFAULT 1,
			source_type INTEGER NOT NULL DEFAULT 0,
			total_amount INTEGER NOT NULL DEFAULT 0,
			actual_amount INTEGER NOT NULL DEFAULT 0,
			remark TEXT,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			deleted_at DATETIME
		);`},
		{"activity_draw_logs", `CREATE TABLE activity_draw_logs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			order_id INTEGER NOT NULL,
			issue_id INTEGER NOT NULL
		);`},
		{"activity_issues", `CREATE TABLE activity_issues (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			activity_id INTEGER NOT NULL
		);`},
		{"activities", `CREATE TABLE activities (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			price_draw INTEGER NOT NULL DEFAULT 0
		);`},
		{"user_invites", `CREATE TABLE user_invites (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			inviter_id INTEGER NOT NULL,
			invitee_id INTEGER NOT NULL,
			accumulated_amount INTEGER NOT NULL DEFAULT 0,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			deleted_at DATETIME
		);`},
	}

	for _, tbl := range tables {
		if db.Migrator().HasTable(tbl.name) {
			continue
		}
		if err := db.Exec(tbl.ddl).Error; err != nil {
			t.Fatalf("创建 %s 表失败: %v", tbl.name, err)
		}
	}
}

// mustExecForServiceTest runs a fixture statement and aborts the test if it
// fails. Without this check a broken fixture (schema mismatch, duplicate key)
// would only surface later as a confusing mismatch in the progress assertions.
func mustExecForServiceTest(t *testing.T, db *gorm.DB, query string, args ...any) {
	t.Helper()
	if err := db.Exec(query, args...).Error; err != nil {
		t.Fatalf("准备测试数据失败: %v (sql: %s)", err, query)
	}
}

// TestGetUserProgress_TimeWindow_Integration checks that each tier window
// (daily / weekly / monthly / activity period / lifetime) only counts the
// orders and invites that fall inside that window.
func TestGetUserProgress_TimeWindow_Integration(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	ensureExtraTablesForServiceTest(t, db)
	svc := New(nil, repo, nil, nil, nil)

	now := time.Now()
	taskStart := now.Add(-200 * 24 * time.Hour)
	taskEnd := now.Add(200 * 24 * time.Hour)

	// A task whose validity period spans now ±200 days: it covers the order
	// from two months ago but not the one from a year ago.
	task := &tcmodel.Task{
		Name:        "时效性测试任务",
		Description: "测试各档位时效隔离",
		Status:      1,
		Visibility:  1,
		StartTime:   &taskStart,
		EndTime:     &taskEnd,
	}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	mustExecForServiceTest(t, db, "INSERT INTO activities (id, price_draw) VALUES (1, 100)")
	mustExecForServiceTest(t, db, "INSERT INTO activity_issues (id, activity_id) VALUES (1, 1)")

	// One order-count tier per window so each window is verified in isolation.
	windows := []string{WindowDaily, WindowWeekly, WindowMonthly, WindowActivityPeriod, WindowLifetime}
	tierIDMap := make(map[string]int64, len(windows))
	for _, w := range windows {
		tier := &tcmodel.TaskTier{
			TaskID:     task.ID,
			Metric:     MetricOrderCount,
			Operator:   OperatorGTE,
			Threshold:  1,
			Window:     w,
			ActivityID: 0,
		}
		if err := db.Create(tier).Error; err != nil {
			t.Fatalf("创建档位失败: %v", err)
		}
		tierIDMap[w] = tier.ID
	}

	userID := int64(888)
	// Three orders (each with a draw log) and three invites, placed at
	// different ages: O1 now, O2 two months ago, O3 one year ago.
	fixtures := []struct {
		orderID   int64
		inviteeID int64
		createdAt time.Time
	}{
		{orderID: 101, inviteeID: 901, createdAt: now},
		{orderID: 102, inviteeID: 902, createdAt: now.AddDate(0, -2, 0)},
		{orderID: 103, inviteeID: 903, createdAt: now.AddDate(-1, 0, 0)},
	}
	for _, f := range fixtures {
		ts := f.createdAt.Format(time.DateTime)
		mustExecForServiceTest(t, db,
			"INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (?, ?, 2, 0, 100, ?)",
			f.orderID, userID, ts)
		mustExecForServiceTest(t, db,
			"INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (?, 1)", f.orderID)
		mustExecForServiceTest(t, db,
			"INSERT INTO user_invites (inviter_id, invitee_id, created_at) VALUES (?, ?, ?)",
			userID, f.inviteeID, ts)
	}

	progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
	if err != nil {
		t.Fatalf("获取进度失败: %v", err)
	}

	// Each tier's order and invite counts must match what its window covers.
	for w, tid := range tierIDMap {
		tp, ok := progress.TierProgressMap[tid]
		if !ok {
			t.Errorf("缺少 %s 的进度", w)
			continue
		}
		var expectedCount int64
		switch w {
		case WindowDaily, WindowWeekly, WindowMonthly:
			expectedCount = 1 // O1 only
		case WindowActivityPeriod:
			expectedCount = 2 // O1, O2
		case WindowLifetime:
			expectedCount = 3 // O1, O2, O3
		default:
			// Guard against a window being added to `windows` without an expectation.
			t.Fatalf("no expected count defined for window %q", w)
		}
		if tp.OrderCount != expectedCount {
			t.Errorf("[%s] OrderCount 不符: Expected %d, Got %d", w, expectedCount, tp.OrderCount)
		} else {
			t.Logf("[%s] OrderCount 验证成功: %d", w, tp.OrderCount)
		}
		if tp.InviteCount != expectedCount {
			t.Errorf("[%s] InviteCount 不符: Expected %d, Got %d", w, expectedCount, tp.InviteCount)
		} else {
			t.Logf("[%s] InviteCount 验证成功: %d", w, tp.InviteCount)
		}
	}
}

// TestUpsertTaskRewards_AllowsMultipleRewardsSameType checks that a tier can
// hold several rewards of the same type, and that a single UpsertTaskRewards
// call can update an existing reward, insert a new one and delete another.
func TestUpsertTaskRewards_AllowsMultipleRewardsSameType(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	svc := New(nil, repo, nil, nil, nil)

	task := &tcmodel.Task{Name: "奖励重入", Description: "测试奖励更新", Status: 1, Visibility: 1}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	tier := &tcmodel.TaskTier{
		TaskID:    task.ID,
		Metric:    MetricOrderCount,
		Operator:  OperatorGTE,
		Threshold: 1,
		Window:    WindowLifetime,
	}
	if err := db.Create(tier).Error; err != nil {
		t.Fatalf("创建档位失败: %v", err)
	}

	// Two coupon rewards on the same tier.
	initialRewards := []TaskRewardInput{
		{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: datatypes.JSON([]byte(`{"coupon_id":1,"quantity":1}`)), Quantity: 1},
		{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: datatypes.JSON([]byte(`{"coupon_id":2,"quantity":1}`)), Quantity: 2},
	}
	if err := svc.UpsertTaskRewards(context.Background(), task.ID, initialRewards, nil); err != nil {
		t.Fatalf("首次保存奖励失败: %v", err)
	}

	var stored []tcmodel.TaskReward
	if err := db.Where("task_id = ?", task.ID).Order("id asc").Find(&stored).Error; err != nil {
		t.Fatalf("查询奖励失败: %v", err)
	}
	if len(stored) != 2 {
		t.Fatalf("奖励数量不正确, 期望 2 实际 %d", len(stored))
	}

	// Update the first reward (by ID), add a new same-type reward, and delete
	// the second original reward.
	updatePayload := datatypes.JSON([]byte(`{"coupon_id":99,"quantity":3}`))
	secondPayload := datatypes.JSON([]byte(`{"coupon_id":200,"quantity":1}`))
	updateInput := []TaskRewardInput{
		{ID: stored[0].ID, TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: updatePayload, Quantity: 5},
		{TierID: tier.ID, RewardType: RewardTypeCoupon, RewardPayload: secondPayload, Quantity: 1},
	}
	if err := svc.UpsertTaskRewards(context.Background(), task.ID, updateInput, []int64{stored[1].ID}); err != nil {
		t.Fatalf("更新奖励失败: %v", err)
	}

	var refreshed []tcmodel.TaskReward
	if err := db.Where("task_id = ?", task.ID).Order("id asc").Find(&refreshed).Error; err != nil {
		t.Fatalf("查询更新后奖励失败: %v", err)
	}
	if len(refreshed) != 2 {
		t.Fatalf("更新后奖励数量不正确, 期望 2 实际 %d", len(refreshed))
	}
	if refreshed[0].ID != stored[0].ID {
		t.Fatalf("原有奖励记录未被更新")
	}

	var pl map[string]int64
	if err := json.Unmarshal(refreshed[0].RewardPayload, &pl); err != nil {
		t.Fatalf("解析奖励 payload 失败: %v", err)
	}
	if pl["coupon_id"] != 99 {
		t.Errorf("奖励 payload 未更新, 期望 99 实际 %d", pl["coupon_id"])
	}
	if refreshed[0].Quantity != 5 {
		t.Errorf("奖励数量未更新, 期望 5 实际 %d", refreshed[0].Quantity)
	}
	for _, r := range refreshed {
		if r.ID == stored[1].ID {
			t.Fatalf("待删除的奖励仍存在, id=%d", r.ID)
		}
	}

	// The remaining row must be the newly inserted same-type reward.
	var added map[string]int64
	if err := json.Unmarshal(refreshed[1].RewardPayload, &added); err != nil {
		t.Fatalf("解析新增奖励 payload 失败: %v", err)
	}
	if added["coupon_id"] != 200 {
		t.Errorf("新增奖励 payload 不正确, 期望 200 实际 %d", added["coupon_id"])
	}
}

// TestGetUserProgress_UsesEffectiveAmount checks that the order amount uses
// price_draw × number of draws when the activity has a positive price_draw,
// and falls back to total_amount otherwise. It also checks per-tier activity
// filtering.
func TestGetUserProgress_UsesEffectiveAmount(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	ensureExtraTablesForServiceTest(t, db)
	svc := New(nil, repo, nil, nil, nil)

	task := &tcmodel.Task{Name: "真实消费口径", Status: 1, Visibility: 1}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	tier := &tcmodel.TaskTier{
		TaskID:     task.ID,
		Metric:     MetricOrderAmount,
		Operator:   OperatorGTE,
		Threshold:  1,
		Window:     WindowLifetime,
		ActivityID: 201,
	}
	if err := db.Create(tier).Error; err != nil {
		t.Fatalf("创建档位失败: %v", err)
	}
	secondaryTier := &tcmodel.TaskTier{
		TaskID:     task.ID,
		Metric:     MetricOrderAmount,
		Operator:   OperatorGTE,
		Threshold:  1,
		Window:     WindowLifetime,
		ActivityID: 202,
	}
	if err := db.Create(secondaryTier).Error; err != nil {
		t.Fatalf("创建第二个档位失败: %v", err)
	}

	mustExecForServiceTest(t, db, "INSERT INTO activities (id, price_draw) VALUES (201, 1000)")
	mustExecForServiceTest(t, db, "INSERT INTO activities (id, price_draw) VALUES (202, 0)")
	mustExecForServiceTest(t, db, "INSERT INTO activity_issues (id, activity_id) VALUES (301, 201)")
	mustExecForServiceTest(t, db, "INSERT INTO activity_issues (id, activity_id) VALUES (302, 202)")

	userID := int64(6001)
	now := time.Now()
	inside := now.Format(time.DateTime)

	// Pass-card order: total_amount=0, but price_draw>0 and two draws, so its
	// effective amount is 2 × 1000.
	mustExecForServiceTest(t, db, "INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (401, ?, 2, 0, 0, ?)", userID, inside)
	mustExecForServiceTest(t, db, "INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (401, 301)")
	mustExecForServiceTest(t, db, "INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (401, 301)")

	// Cash order: price_draw=0, so the amount falls back to total_amount.
	mustExecForServiceTest(t, db, "INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (402, ?, 2, 0, 1500, ?)", userID, inside)
	mustExecForServiceTest(t, db, "INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (402, 302)")

	progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
	if err != nil {
		t.Fatalf("获取进度失败: %v", err)
	}

	// Overall: 2000 (pass-card) + 1500 (cash).
	if progress.OrderAmount != 3500 {
		t.Fatalf("订单金额统计错误,期望 3500 实际 %d", progress.OrderAmount)
	}
	if progress.OrderCount != 2 {
		t.Fatalf("订单数量统计错误,期望 2 实际 %d", progress.OrderCount)
	}

	// The tier bound to activity 201 only sees the pass-card order.
	tierProgress, ok := progress.TierProgressMap[tier.ID]
	if !ok {
		t.Fatalf("未找到档位进度")
	}
	if tierProgress.OrderAmount != 2000 {
		t.Fatalf("档位金额错误,期望 2000 实际 %d", tierProgress.OrderAmount)
	}
	if tierProgress.OrderCount != 1 {
		t.Fatalf("档位订单数错误,期望 1 实际 %d", tierProgress.OrderCount)
	}
}

// TestTimeWindow_ActivityPeriod checks that an activity-period tier only
// counts orders placed between the task's StartTime and EndTime. The overall
// progress still counts every order.
func TestTimeWindow_ActivityPeriod(t *testing.T) {
	repo, err := mysql.NewSQLiteRepoForTest()
	if err != nil {
		t.Fatalf("创建 repo 失败: %v", err)
	}
	db := repo.GetDbW()
	initTestTables(t, db)
	ensureExtraTablesForServiceTest(t, db)
	svc := New(nil, repo, nil, nil, nil)

	// A 10-day task window that started one month ago (and has already ended).
	start := time.Now().AddDate(0, -1, 0)
	end := start.AddDate(0, 0, 10)
	task := &tcmodel.Task{
		Name:       "任务窗口期",
		Status:     1,
		Visibility: 1,
		StartTime:  &start,
		EndTime:    &end,
	}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("创建任务失败: %v", err)
	}
	tier := &tcmodel.TaskTier{
		TaskID:     task.ID,
		Metric:     MetricOrderCount,
		Operator:   OperatorGTE,
		Threshold:  1,
		Window:     WindowActivityPeriod,
		ActivityID: 501,
	}
	if err := db.Create(tier).Error; err != nil {
		t.Fatalf("创建档位失败: %v", err)
	}

	mustExecForServiceTest(t, db, "INSERT INTO activities (id, price_draw) VALUES (501, 500)")
	mustExecForServiceTest(t, db, "INSERT INTO activity_issues (id, activity_id) VALUES (601, 501)")

	userID := int64(7007)
	inside := start.Add(24 * time.Hour).Format(time.DateTime) // one day after start
	outside := end.Add(24 * time.Hour).Format(time.DateTime)  // one day after end

	mustExecForServiceTest(t, db, "INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (701, ?, 2, 0, 0, ?)", userID, inside)
	mustExecForServiceTest(t, db, "INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (701, 601)")
	mustExecForServiceTest(t, db, "INSERT INTO orders (id, user_id, status, source_type, total_amount, created_at) VALUES (702, ?, 2, 0, 0, ?)", userID, outside)
	mustExecForServiceTest(t, db, "INSERT INTO activity_draw_logs (order_id, issue_id) VALUES (702, 601)")

	progress, err := svc.GetUserProgress(context.Background(), userID, task.ID)
	if err != nil {
		t.Fatalf("获取进度失败: %v", err)
	}

	tierProgress, ok := progress.TierProgressMap[tier.ID]
	if !ok {
		t.Fatalf("未找到活动有效期档位进度")
	}
	if tierProgress.OrderCount != 1 {
		t.Fatalf("活动有效期窗口统计错误,期望 1 实际 %d", tierProgress.OrderCount)
	}
	if progress.OrderCount != 2 {
		t.Fatalf("总体订单统计错误,期望 2 实际 %d", progress.OrderCount)
	}
}