package activity

import (
	"context"
	"mime/multipart"
	"os"
	"sync"
	"testing"
	"time"

	drivermysql "gorm.io/driver/mysql"
	"gorm.io/gorm"

	"bindbox-game/internal/pkg/core"
	"bindbox-game/internal/pkg/logger"
	"bindbox-game/internal/repository/mysql"
	"bindbox-game/internal/repository/mysql/dao"
	"bindbox-game/internal/repository/mysql/model"
)

// mockContext is a minimal core.Context stand-in for service-level tests.
//
// Only RequestContext carries real behavior: it hands ctx to the code under
// test. The embedded core.Context is left nil, so any interface method that is
// not overridden below (for example Header) is promoted from a nil interface
// and will panic if the code under test ever calls it.
type mockContext struct {
	core.Context // nil; embedded so *mockContext satisfies core.Context
	ctx          context.Context
}

// RequestContext wraps the mock's context.Context in a core.StdContext.
func (m *mockContext) RequestContext() core.StdContext { return core.StdContext{Context: m.ctx} }

// The methods below are inert stubs: binders report success without touching
// their argument, accessors return zero values, and writers do nothing.
func (m *mockContext) ShouldBindQuery(obj interface{}) error    { return nil }
func (m *mockContext) ShouldBindPostForm(obj interface{}) error { return nil }
func (m *mockContext) ShouldBindForm(obj interface{}) error     { return nil }
func (m *mockContext) ShouldBindJSON(obj interface{}) error     { return nil }
func (m *mockContext) ShouldBindXML(obj interface{}) error      { return nil }
func (m *mockContext) ShouldBindURI(obj interface{}) error      { return nil }
func (m *mockContext) Redirect(code int, location string)       {}
func (m *mockContext) Trace() core.Trace                        { return nil }
func (m *mockContext) Logger() logger.CustomLogger              { return nil }
func (m *mockContext) Payload(payload interface{})              {}
func (m *mockContext) File(filePath string)                     {}
func (m *mockContext) HTML(name string, obj interface{})        {}
func (m *mockContext) String(str string)                        {}
func (m *mockContext) XML(obj interface{})                      {}
func (m *mockContext) ExcelData(filename string, byteData []byte) {}
func (m *mockContext) FormFile(name string) (*multipart.FileHeader, error) {
	return nil, nil
}
func (m *mockContext) SaveUploadedFile(file *multipart.FileHeader, dst string) error {
	return nil
}
func (m *mockContext) AbortWithError(err core.BusinessError) {}

// NOTE: Go qualifies unexported method names by their declaring package, so
// these methods (declared in package activity) can never override unexported
// methods of core.Context, even where the names match. Calls made through a
// core.Context value reach the nil embedded interface instead. The stubs are
// retained only to keep mockContext's method set unchanged.
func (m *mockContext) setTrace(trace core.Trace)             {}
func (m *mockContext) disableTrace()                         {}
func (m *mockContext) setLogger(logger logger.CustomLogger)  {}
func (m *mockContext) getPayload() interface{}               { return nil }
func (m *mockContext) abortError() core.BusinessError        { return nil }

// TestConcurrencyCoupon verifies that concurrent orders sharing one amount
// coupon do not over-deduct its balance. After the orders run, the coupon's
// final balance must be non-negative. The amount deducted from it must also
// equal the sum of applied amounts recorded in order_coupons.
//
// This is an integration test against a real MySQL database. It is skipped
// unless BINDBOX_TEST_MYSQL_DSN is set, for example:
//
//	BINDBOX_TEST_MYSQL_DSN='user:pass@tcp(host:3306)/dev_game?charset=utf8mb4&parseTime=true&loc=Local'
//
// The test inserts a user coupon and whatever orders the service creates.
// That data is not cleaned up afterwards.
func TestConcurrencyCoupon(t *testing.T) {
	const (
		userID         = int64(99999) // dedicated test user
		initialBalance = 5000         // 50 yuan
		concurrency    = 20
	)

	// 1. Setup DB. The DSN is read from the environment so that no
	// credentials are committed to source control.
	dsn := os.Getenv("BINDBOX_TEST_MYSQL_DSN")
	if dsn == "" {
		t.Skip("BINDBOX_TEST_MYSQL_DSN not set; skipping DB-backed concurrency test")
	}
	db, err := gorm.Open(drivermysql.Open(dsn), &gorm.Config{})
	if err != nil {
		t.Skipf("Skipping test due to DB connection failure: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("Failed to get underlying sql.DB: %v", err)
	}
	// Release the connection pool once the test (and its goroutines) finish.
	t.Cleanup(func() { _ = sqlDB.Close() })

	// Initialize the generated dao package with this connection.
	dao.Use(db)
	// Use the helper from the mysql package to get a valid Repo implementation.
	repo := mysql.NewTestRepo(db)
	log, err := logger.NewCustomLogger(nil, logger.WithOutputInConsole())
	if err != nil {
		t.Fatalf("Failed to create logger: %v", err)
	}

	// 2. Init Service
	svc := NewActivityOrderService(log, repo)

	// 3. Prepare Data: find a valid system coupon (Type 1 - Amount).
	var sysCoupon model.SystemCoupons
	if err := db.Where("discount_type = 1 AND status = 1").First(&sysCoupon).Error; err != nil {
		t.Skipf("No valid system coupon found (Type 1), skipping: %v", err)
	}

	// Create the user coupon that every concurrent order will try to use.
	userCoupon := model.UserCoupons{
		UserID:        userID,
		CouponID:      sysCoupon.ID,
		Status:        1,
		BalanceAmount: initialBalance,
		ValidStart:    time.Now().Add(-1 * time.Hour),
		ValidEnd:      time.Now().Add(24 * time.Hour),
		CreatedAt:     time.Now(),
	}
	if err := db.Omit("UsedAt", "UsedOrderID").Create(&userCoupon).Error; err != nil {
		t.Fatalf("Failed to create user coupon: %v", err)
	}
	t.Logf("Created test coupon ID: %d with balance %d", userCoupon.ID, initialBalance)

	// 4. Concurrency Test
	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards successCount and failCount
		successCount int
		failCount    int
	)
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			req := CreateActivityOrderRequest{
				UserID:     userID,
				ActivityID: 1, // Dummy
				IssueID:    1, // Dummy
				Count:      1,
				UnitPrice:  1000, // 10 yuan
				SourceType: 2,
				CouponID:   &userCoupon.ID,
			}
			mockCtx := &mockContext{ctx: context.Background()}
			res, err := svc.CreateActivityOrder(mockCtx, req)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failCount++
				t.Logf("[%d] Failed: %v", idx, err)
				return
			}
			t.Logf("[%d] Success. Discount: %d", idx, res.AppliedCouponVal)
			if res.AppliedCouponVal > 0 {
				successCount++
			}
		}(i)
	}
	wg.Wait()

	// 5. Verify Result
	var finalCoupon model.UserCoupons
	if err := db.First(&finalCoupon, userCoupon.ID).Error; err != nil {
		t.Fatalf("Failed to query final coupon: %v", err)
	}
	t.Logf("Final Balance: %d. Success Orders with Discount: %d. Failures: %d",
		finalCoupon.BalanceAmount, successCount, failCount)
	if finalCoupon.BalanceAmount < 0 {
		t.Errorf("Balance is negative: %d", finalCoupon.BalanceAmount)
	}

	// The balance drop must match what the order_coupons ledger recorded;
	// a query failure here must not masquerade as "nothing deducted".
	var orderCoupons []model.OrderCoupons
	if err := db.Where("user_coupon_id = ?", userCoupon.ID).Find(&orderCoupons).Error; err != nil {
		t.Fatalf("Failed to query order coupons: %v", err)
	}
	totalDeducted := int64(0)
	for _, oc := range orderCoupons {
		totalDeducted += oc.AppliedAmount
	}
	t.Logf("Total Deducted from OrderCoupons table: %d", totalDeducted)

	expectedDeduction := initialBalance - finalCoupon.BalanceAmount
	if expectedDeduction != totalDeducted {
		t.Errorf("Mismatch! Initial-Final(%d) != OrderCoupons Sum(%d)", expectedDeduction, totalDeducted)
	}
}