bindbox-game/internal/service/user/address_share.go
Zuncle bd91c0fad1 fix(transfer): 修复赠送资产并发漏洞及转赠积分薅取问题
- SubmitAddressShare 事务内 SELECT FOR UPDATE 锁定资产行,防止并发重复提交
- 检查 UPDATE RowsAffected,静默失败时回滚事务
- 防重检查从 readDB 移入事务内写库,消除主从延迟竞态
- RedeemInventoryToPoints/RedeemInventoriesToPoints 添加转赠来源校验,
  禁止通过转赠获得的资产兑换积分
2026-03-11 14:14:34 +08:00

824 lines
27 KiB
Go
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package user
import (
"context"
"fmt"
"time"
"bindbox-game/configs"
"bindbox-game/internal/repository/mysql/model"
"bindbox-game/internal/pkg/wechat"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"gorm.io/gorm"
)
// shareClaims is the JWT payload of an address-share token. It ties the token
// to one inventory row and to the user who issued the share.
type shareClaims struct {
// OwnerUserID is the ID of the user who created the share.
OwnerUserID int64 `json:"owner_user_id"`
// InventoryID is the ID of the shared user_inventory row.
InventoryID int64 `json:"inventory_id"`
// RegisteredClaims carries the standard nbf/iat/exp timestamps.
jwt.RegisteredClaims
}
func signShareToken(ownerUserID int64, inventoryID int64, expiresAt time.Time) (string, error) {
claims := shareClaims{
OwnerUserID: ownerUserID,
InventoryID: inventoryID,
RegisteredClaims: jwt.RegisteredClaims{
NotBefore: jwt.NewNumericDate(time.Now()),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(configs.Get().Random.CommitMasterKey))
}
// parseShareToken verifies tokenString and returns the share claims it carries.
//
// Only tokens signed with HS256 (the algorithm used by signShareToken) are
// accepted; any other "alg" header is rejected by the parser before the key is
// used. The HMAC secret is the configured Random.CommitMasterKey.
//
// The returned claims are non-nil if and only if the returned error is nil, so
// callers may dereference the claims after a plain err check.
func parseShareToken(tokenString string) (*shareClaims, error) {
	parsed, err := jwt.ParseWithClaims(
		tokenString,
		&shareClaims{},
		func(token *jwt.Token) (interface{}, error) {
			return []byte(configs.Get().Random.CommitMasterKey), nil
		},
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}
	if parsed == nil || !parsed.Valid {
		return nil, fmt.Errorf("invalid share token")
	}
	claims, ok := parsed.Claims.(*shareClaims)
	if !ok || claims == nil {
		// Never hand back (nil, nil): the caller dereferences the claims.
		return nil, fmt.Errorf("invalid share token")
	}
	return claims, nil
}
// CreateAddressShare issues a share token that lets another person submit a
// shipping address for one of userID's inventory items, and tries to wrap the
// token in a WeChat mini-program link.
//
// The inventory row must belong to userID and have status 1. Link generation
// is best-effort: a ShortLink is attempted first, then a URL Scheme; if WeChat
// is not configured or both attempts fail, the token is still returned with an
// empty link.
//
// It returns the token, the link (possibly empty), expiresAt as passed in, and
// an error only when the inventory lookup or the token signing fails.
func (s *service) CreateAddressShare(ctx context.Context, userID int64, inventoryID int64, expiresAt time.Time) (string, string, time.Time, error) {
	inv, err := s.readDB.UserInventory.WithContext(ctx).Where(
		s.readDB.UserInventory.UserID.Eq(userID),
		s.readDB.UserInventory.ID.Eq(inventoryID),
		s.readDB.UserInventory.Status.Eq(1),
	).First()
	if err != nil {
		return "", "", time.Time{}, err
	}

	token, err := signShareToken(userID, inventoryID, expiresAt)
	if err != nil {
		return "", "", time.Time{}, err
	}

	// Without WeChat credentials there is nothing more to do.
	conf := configs.Get()
	if conf.Wechat.AppID == "" || conf.Wechat.AppSecret == "" {
		s.logger.Warn("微信配置缺失,跳过短链生成", zap.String("appid", conf.Wechat.AppID))
		return token, "", expiresAt, nil
	}

	accessToken, tokenErr := wechat.GetAccessTokenWithContext(ctx, &wechat.WechatConfig{
		AppID:     conf.Wechat.AppID,
		AppSecret: conf.Wechat.AppSecret,
	})
	if tokenErr != nil {
		s.logger.Error("获取微信AccessToken失败", zap.Error(tokenErr))
		return token, "", expiresAt, nil
	}

	// The address form lives in the pages-user sub-package, hence the prefix.
	const submitPage = "pages-user/address/submit"
	query := "token=" + token
	pagePath := submitPage + "?" + query

	title := "送你一个好礼,快来填写地址领走吧!"
	if inv.Remark != "" {
		title = fmt.Sprintf("送你一个%s快来领走吧", inv.Remark)
	}

	// First choice: a mini-program ShortLink.
	link, linkErr := wechat.GetShortLink(accessToken, pagePath, title)
	if linkErr == nil {
		s.logger.Info("成功生成微信短链", zap.String("short_link", link))
		return token, link, expiresAt, nil
	}

	// Fallback: a URL Scheme, which takes the path and the query separately.
	s.logger.Warn("生成微信短链失败尝试降级为Scheme", zap.Error(linkErr), zap.String("page_path", pagePath))
	scheme, schemeErr := wechat.GenerateScheme(accessToken, submitPage, query, "release")
	if schemeErr != nil {
		s.logger.Error("生成微信Scheme也失败", zap.Error(schemeErr))
		return token, "", expiresAt, nil
	}
	s.logger.Info("成功生成微信Scheme", zap.String("scheme", scheme))
	return token, scheme, expiresAt, nil
}
// RevokeAddressShare is currently a no-op: it ignores all of its arguments and
// always returns nil. Nothing here invalidates a previously issued share token.
func (s *service) RevokeAddressShare(ctx context.Context, userID int64, inventoryID int64) error {
return nil
}
// SubmitAddressShare consumes a share token: it stores the submitted shipping
// address, marks the shared inventory item as shipping-requested (status 3)
// and creates the shipping record, all inside one write-DB transaction.
//
// If submittedByUserID is set, positive and different from the token owner,
// the item is transferred to that user (a transfer log row is written and the
// inventory's user_id is changed); otherwise the item stays with the owner.
// submittedIP may be nil, in which case an empty IP is recorded.
//
// Returns the ID of the created address row. Errors:
//   - "invalid_or_expired_token": the token could not be verified
//   - "inventory_unavailable": the item is missing, not in status 1, or was
//     changed concurrently (the guarded UPDATE affected no row)
//   - "already_processed": a non-cancelled shipping record already exists
//   - any database error from the transaction
func (s *service) SubmitAddressShare(ctx context.Context, shareToken string, name string, mobile string, province string, city string, district string, address string, submittedByUserID *int64, submittedIP *string) (int64, error) {
	claims, err := parseShareToken(shareToken)
	if err != nil || claims == nil {
		// Log only a short prefix of the token; short or empty tokens must not
		// panic on slicing.
		masked := shareToken
		if len(masked) > 10 {
			masked = masked[:10]
		}
		s.logger.Error("SubmitAddressShare: Token parse failed", zap.Error(err), zap.String("token_masked", masked+"..."))
		return 0, fmt.Errorf("invalid_or_expired_token")
	}
	s.logger.Info("SubmitAddressShare: Processing", zap.Int64("invID", claims.InventoryID), zap.Int64("owner", claims.OwnerUserID))

	// 1. Decide who ends up owning the item (transfer when a different,
	// identified user submits the address).
	targetUserID := claims.OwnerUserID
	isTransfer := false
	if submittedByUserID != nil && *submittedByUserID > 0 && *submittedByUserID != claims.OwnerUserID {
		targetUserID = *submittedByUserID
		isTransfer = true
	}

	// The caller may not know the client IP.
	clientIP := ""
	if submittedIP != nil {
		clientIP = *submittedIP
	}

	var addrID int64
	err = s.repo.GetDbW().Transaction(func(tx *gorm.DB) error {
		// a. Lock the inventory row (SELECT ... FOR UPDATE) so concurrent
		// submissions of the same token serialize here.
		var inv model.UserInventory
		lockResult := tx.Raw("SELECT * FROM user_inventory WHERE id = ? FOR UPDATE", claims.InventoryID).Scan(&inv)
		if lockResult.Error != nil {
			s.logger.Error("SubmitAddressShare: Lock inventory failed", zap.Int64("invID", claims.InventoryID), zap.Error(lockResult.Error))
			return lockResult.Error
		}
		if inv.ID == 0 {
			s.logger.Warn("SubmitAddressShare: Inventory not found", zap.Int64("invID", claims.InventoryID))
			return fmt.Errorf("inventory_unavailable")
		}
		if inv.Status != 1 {
			s.logger.Warn("SubmitAddressShare: Inventory unavailable", zap.Int64("invID", claims.InventoryID), zap.Int32("status", inv.Status))
			return fmt.Errorf("inventory_unavailable")
		}

		// b. Duplicate check against the write DB, inside the transaction, so
		// replication lag cannot hide an existing shipping record.
		var shipCnt int64
		if err := tx.Raw("SELECT COUNT(*) FROM shipping_records WHERE inventory_id = ? AND status != 5", claims.InventoryID).Scan(&shipCnt).Error; err != nil {
			return err
		}
		if shipCnt > 0 {
			s.logger.Warn("SubmitAddressShare: Already processed", zap.Int64("invID", claims.InventoryID))
			return fmt.Errorf("already_processed")
		}

		// c. Create the address for targetUserID. It becomes the default only
		// when the lookup succeeded and found no existing default; on a failed
		// lookup the address is stored as non-default.
		arow := &model.UserAddresses{
			UserID:    targetUserID,
			Name:      name,
			Mobile:    mobile,
			Province:  province,
			City:      city,
			District:  district,
			Address:   address,
			IsDefault: 0,
		}
		defaultCnt, cntErr := s.readDB.UserAddresses.WithContext(ctx).Where(
			s.readDB.UserAddresses.UserID.Eq(targetUserID),
			s.readDB.UserAddresses.IsDefault.Eq(1),
		).Count()
		if cntErr != nil {
			s.logger.Warn("SubmitAddressShare: Default address lookup failed", zap.Int64("userID", targetUserID), zap.Error(cntErr))
		} else if defaultCnt == 0 {
			arow.IsDefault = 1
		}
		if err := tx.Omit("DefaultUserUnique").Create(arow).Error; err != nil {
			return err
		}
		addrID = arow.ID

		// d. Move the item to status 3 and, for a transfer, to its new owner.
		updates := map[string]interface{}{
			"status":     3,
			"updated_at": time.Now(),
			"remark":     "shipping_requested_via_share",
		}
		if isTransfer {
			// Record the transfer before changing ownership.
			transferLog := &model.UserInventoryTransfers{
				InventoryID: claims.InventoryID,
				FromUserID:  claims.OwnerUserID,
				ToUserID:    targetUserID,
				Remark:      "address_share_transfer",
			}
			if err := tx.Create(transferLog).Error; err != nil {
				return err
			}
			updates["user_id"] = targetUserID
			updates["remark"] = fmt.Sprintf("transferred_from_%d|shipping_requested", claims.OwnerUserID)
		}
		// The WHERE clause re-asserts owner and status; zero affected rows
		// means the item is no longer in the expected state, so roll back.
		result := tx.Table("user_inventory").
			Where("id = ? AND user_id = ? AND status = 1", claims.InventoryID, claims.OwnerUserID).
			Updates(updates)
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return fmt.Errorf("inventory_unavailable")
		}

		// e. Create the shipping record for targetUserID, priced from the
		// inventory's value snapshot.
		price := inv.ValueCents
		if price <= 0 && inv.ProductID > 0 {
			// No snapshot yet: fall back to the current product price and
			// write it back as the snapshot (best-effort).
			var p model.Products
			if err := tx.Table("products").Where("id = ?", inv.ProductID).First(&p).Error; err == nil {
				price = p.Price
				fixErr := tx.Table("user_inventory").Where("id = ?", inv.ID).Updates(map[string]interface{}{
					"value_cents":       price,
					"value_source":      2,
					"value_snapshot_at": time.Now(),
				}).Error
				if fixErr != nil {
					s.logger.Warn("SubmitAddressShare: Value snapshot write-back failed", zap.Int64("invID", inv.ID), zap.Error(fixErr))
				}
			}
		}

		// Batch numbers for share submissions use a "T" prefix (regular batch
		// shipping uses "B").
		transferBatchNo := fmt.Sprintf("T%d%d", targetUserID, time.Now().UnixNano()/1000000)
		shipRecord := &model.ShippingRecords{
			UserID:      targetUserID,
			OrderID:     inv.OrderID,
			OrderItemID: 0,
			InventoryID: claims.InventoryID,
			ProductID:   inv.ProductID,
			Quantity:    1,
			Price:       price,
			AddressID:   addrID,
			Status:      1,
			BatchNo:     transferBatchNo,
			Remark:      fmt.Sprintf("shared_address_submit|ip=%s|transfer_from=%d", clientIP, claims.OwnerUserID),
		}
		return tx.Omit("ShippedAt", "ReceivedAt").Create(shipRecord).Error
	})
	if err != nil {
		return 0, err
	}
	return addrID, nil
}
// RequestShipping requests shipment of a single inventory item with no batch
// number and no explicit address; it delegates to RequestShippingWithBatch
// with an empty batch number and address ID 0.
func (s *service) RequestShipping(ctx context.Context, userID int64, inventoryID int64) (int64, error) {
	const (
		noBatchNo          = ""
		noAddressID  int64 = 0
	)
	return s.RequestShippingWithBatch(ctx, userID, inventoryID, noBatchNo, noAddressID)
}
// RequestShippingWithBatch requests shipment of one inventory item, optionally
// tagging the shipping record with batchNo and using a specific address.
//
// The item must belong to userID and be in status 1. When addrID <= 0 the
// user's default address is used. The price recorded is the item's value
// snapshot; if none exists, the current product price is used and written back
// as the snapshot.
//
// The state change and the shipping record are written in one transaction:
// the item is first claimed with a status-guarded UPDATE, and the request is
// rejected when that UPDATE changes no row, so two concurrent requests cannot
// both create a shipping record.
//
// Returns the address ID used. Errors: "already_processed" when a
// non-cancelled shipping record exists, "inventory_unavailable" when the item
// was changed concurrently, or any lookup/database error.
func (s *service) RequestShippingWithBatch(ctx context.Context, userID int64, inventoryID int64, batchNo string, addrID int64) (int64, error) {
	// Cheap early rejection on the read DB. A lookup error is deliberately not
	// fatal here because the authoritative check runs inside the transaction.
	cnt, err := s.readDB.ShippingRecords.WithContext(ctx).Where(
		s.readDB.ShippingRecords.InventoryID.Eq(inventoryID),
		s.readDB.ShippingRecords.Status.Neq(5), // ignore cancelled records
	).Count()
	if err == nil && cnt > 0 {
		return 0, fmt.Errorf("already_processed")
	}

	inv, err := s.readDB.UserInventory.WithContext(ctx).Where(
		s.readDB.UserInventory.UserID.Eq(userID),
		s.readDB.UserInventory.ID.Eq(inventoryID),
		s.readDB.UserInventory.Status.Eq(1),
	).First()
	if err != nil {
		return 0, err
	}

	// Fall back to the default address when none was given.
	if addrID <= 0 {
		addr, err := s.readDB.UserAddresses.WithContext(ctx).Where(
			s.readDB.UserAddresses.UserID.Eq(userID),
			s.readDB.UserAddresses.IsDefault.Eq(1),
		).First()
		if err != nil {
			return 0, err
		}
		addrID = addr.ID
	}

	// Prefer the value snapshot; otherwise use the current product price and
	// remember to persist it as the snapshot.
	price := inv.ValueCents
	writeSnapshot := false
	if price <= 0 && inv.ProductID > 0 {
		if p, e := s.readDB.Products.WithContext(ctx).Where(s.readDB.Products.ID.Eq(inv.ProductID)).First(); e == nil && p != nil {
			price = p.Price
			writeSnapshot = true
		}
	}

	err = s.repo.GetDbW().Transaction(func(tx *gorm.DB) error {
		// Claim the item. Only one concurrent request can flip status 1 -> 3.
		claim := tx.Exec("UPDATE user_inventory SET status=3, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|shipping_requested') WHERE id=? AND user_id=? AND status=1", inventoryID, userID)
		if claim.Error != nil {
			return claim.Error
		}
		if claim.RowsAffected == 0 {
			return fmt.Errorf("inventory_unavailable")
		}

		// Authoritative duplicate check on the write DB.
		var shipCnt int64
		if err := tx.Raw("SELECT COUNT(*) FROM shipping_records WHERE inventory_id = ? AND status != 5", inventoryID).Scan(&shipCnt).Error; err != nil {
			return err
		}
		if shipCnt > 0 {
			return fmt.Errorf("already_processed")
		}

		if writeSnapshot {
			if err := tx.Exec("UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=NOW(3) WHERE id=? AND user_id=?",
				price, 2, inventoryID, userID).Error; err != nil {
				return err
			}
		}

		return tx.Exec("INSERT INTO shipping_records (user_id, order_id, order_item_id, inventory_id, product_id, quantity, price, address_id, status, batch_no, remark) VALUES (?,?,?,?,?,?,?,?,?,?,?)",
			userID, inv.OrderID, 0, inventoryID, inv.ProductID, 1, price, addrID, 1, batchNo, "user_request_shipping").Error
	})
	if err != nil {
		return 0, err
	}
	return addrID, nil
}
// generateBatchNo builds a batch number of the form "B<userID><unix-millis>"
// from the current time.
func generateBatchNo(userID int64) string {
	millis := time.Now().UnixNano() / int64(time.Millisecond)
	return fmt.Sprintf("B%d%d", userID, millis)
}
// RequestShippings requests shipment of several inventory items in one batch.
//
// inventoryIDs are de-duplicated and non-positive IDs are dropped. The address
// is *addressID when given (it must belong to userID); otherwise the user's
// default address, or the user's only address, which is then marked default.
//
// Every remaining ID ends up in exactly one result list:
//   - success: shipping record created and item moved to status 3
//   - skipped: not_found, not_owned, already_processed, already_requested or
//     invalid_status
//   - failed: a lookup or the write transaction failed; Reason is the error text
//
// Parameter or address problems are reported as a single failed entry with
// ID 0 ("invalid_params", "address_not_found", "no_default_address").
//
// All writes happen in one transaction. The items are claimed with a single
// status-guarded UPDATE and the transaction is rolled back unless that UPDATE
// changed every item, so a concurrent request cannot ship the same item twice.
// The returned err is always nil; failures are reported through failed.
func (s *service) RequestShippings(ctx context.Context, userID int64, inventoryIDs []int64, addressID *int64) (addrID int64, batchNo string, success []int64, skipped []struct {
	ID     int64
	Reason string
}, failed []struct {
	ID     int64
	Reason string
}, err error) {
	// entry is an alias (not a new type) for the anonymous result element, so
	// []entry is identical to the declared result slice types.
	type entry = struct {
		ID     int64
		Reason string
	}
	// reject builds the "whole request refused" result.
	reject := func(reason string) (int64, string, []int64, []entry, []entry, error) {
		return 0, "", nil, nil, []entry{{ID: 0, Reason: reason}}, nil
	}

	// 1. De-duplicate, dropping non-positive IDs.
	seen := make(map[int64]struct{}, len(inventoryIDs))
	uniq := make([]int64, 0, len(inventoryIDs))
	for _, id := range inventoryIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		uniq = append(uniq, id)
	}
	if len(uniq) == 0 {
		return reject("invalid_params")
	}

	// 2. Resolve the shipping address.
	if addressID != nil && *addressID > 0 {
		ua, _ := s.readDB.UserAddresses.WithContext(ctx).Where(
			s.readDB.UserAddresses.ID.Eq(*addressID),
			s.readDB.UserAddresses.UserID.Eq(userID),
		).First()
		if ua == nil {
			return reject("address_not_found")
		}
		addrID = ua.ID
	} else {
		da, e := s.readDB.UserAddresses.WithContext(ctx).Where(
			s.readDB.UserAddresses.UserID.Eq(userID),
			s.readDB.UserAddresses.IsDefault.Eq(1),
		).First()
		if e == nil && da != nil {
			addrID = da.ID
		} else {
			// No default: if the user has exactly one address, promote it.
			addrs, errFind := s.readDB.UserAddresses.WithContext(ctx).Where(s.readDB.UserAddresses.UserID.Eq(userID)).Find()
			if errFind != nil || len(addrs) != 1 {
				return reject("no_default_address")
			}
			target := addrs[0]
			// NOTE(review): this write goes through s.readDB; confirm that
			// handle is allowed to write.
			if _, errUpd := s.readDB.UserAddresses.WithContext(ctx).Where(s.readDB.UserAddresses.ID.Eq(target.ID)).UpdateSimple(s.readDB.UserAddresses.IsDefault.Value(1)); errUpd != nil {
				s.logger.Error("Auto set default address failed", zap.Error(errUpd))
				return reject("no_default_address")
			}
			addrID = target.ID
		}
	}

	// 3. Batch number and (non-nil) result lists.
	batchNo = generateBatchNo(userID)
	success = make([]int64, 0, len(uniq))
	skipped = make([]entry, 0)
	failed = make([]entry, 0)
	failAll := func(ids []int64, reason string) {
		for _, id := range ids {
			failed = append(failed, entry{ID: id, Reason: reason})
		}
	}

	// 4. Load all requested inventory rows in one query. A failed lookup must
	// not be reported as "not_found", so the whole batch is marked failed.
	invList, invErr := s.readDB.UserInventory.WithContext(ctx).
		Where(s.readDB.UserInventory.ID.In(uniq...)).
		Find()
	if invErr != nil {
		failAll(uniq, invErr.Error())
		return addrID, batchNo, success, skipped, failed, nil
	}
	invMap := make(map[int64]*model.UserInventory, len(invList))
	for _, inv := range invList {
		invMap[inv.ID] = inv
	}

	// 5. Find items that already have a non-cancelled shipping record. If
	// this lookup fails we cannot rule out duplicates, so fail the batch.
	var shippedIDs []int64
	if shipErr := s.repo.GetDbR().Raw("SELECT DISTINCT inventory_id FROM shipping_records WHERE inventory_id IN ? AND status != 5", uniq).Scan(&shippedIDs).Error; shipErr != nil {
		failAll(uniq, shipErr.Error())
		return addrID, batchNo, success, skipped, failed, nil
	}
	shippedSet := make(map[int64]struct{}, len(shippedIDs))
	for _, id := range shippedIDs {
		shippedSet[id] = struct{}{}
	}

	// 6. Classify every ID; collect the shippable ones.
	validInvs := make([]*model.UserInventory, 0, len(uniq))
	validIDs := make([]int64, 0, len(uniq))
	for _, id := range uniq {
		inv := invMap[id]
		_, shipped := shippedSet[id]
		reason := ""
		switch {
		case inv == nil:
			reason = "not_found"
		case inv.UserID != userID:
			reason = "not_owned"
		case shipped:
			reason = "already_processed"
		case inv.Status == 3:
			reason = "already_requested"
		case inv.Status != 1:
			reason = "invalid_status"
		}
		if reason != "" {
			skipped = append(skipped, entry{ID: id, Reason: reason})
			continue
		}
		validInvs = append(validInvs, inv)
		validIDs = append(validIDs, inv.ID)
	}
	if len(validInvs) == 0 {
		return addrID, batchNo, success, skipped, failed, nil
	}

	// 7. Look up current product prices, but only for items that have no
	// value snapshot (the price is not used otherwise).
	productIDs := make([]int64, 0, len(validInvs))
	productIDSet := make(map[int64]struct{}, len(validInvs))
	for _, inv := range validInvs {
		if inv.ValueCents > 0 || inv.ProductID <= 0 {
			continue
		}
		if _, ok := productIDSet[inv.ProductID]; !ok {
			productIDSet[inv.ProductID] = struct{}{}
			productIDs = append(productIDs, inv.ProductID)
		}
	}
	productMap := make(map[int64]int64, len(productIDs)) // productID -> price
	if len(productIDs) > 0 {
		prods, prodErr := s.readDB.Products.WithContext(ctx).Where(s.readDB.Products.ID.In(productIDs...)).Find()
		if prodErr != nil {
			// Shipping with an unknown price would record price 0.
			failAll(validIDs, prodErr.Error())
			return addrID, batchNo, success, skipped, failed, nil
		}
		for _, p := range prods {
			productMap[p.ID] = p.Price
		}
	}

	// 8. Decide each item's price up front so the transaction only writes.
	type shipment struct {
		inv         *model.UserInventory
		price       int64
		valueSource int32
		fixSnapshot bool
	}
	shipments := make([]shipment, 0, len(validInvs))
	for _, inv := range validInvs {
		sh := shipment{inv: inv, price: inv.ValueCents, valueSource: inv.ValueSource}
		if sh.price <= 0 && inv.ProductID > 0 {
			// No snapshot: use the current product price and persist it.
			sh.price = productMap[inv.ProductID]
			sh.valueSource = 2
			sh.fixSnapshot = true
		}
		shipments = append(shipments, sh)
	}

	txErr := s.repo.GetDbW().Transaction(func(tx *gorm.DB) error {
		// Claim all items with one UPDATE. Every row must flip 1 -> 3;
		// otherwise some item changed concurrently and the batch is aborted.
		claim := tx.Exec(
			"UPDATE user_inventory SET status=3, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|batch_shipping_requested') WHERE id IN ? AND user_id=? AND status=1",
			validIDs, userID,
		)
		if claim.Error != nil {
			return claim.Error
		}
		if claim.RowsAffected != int64(len(validIDs)) {
			return fmt.Errorf("inventory_unavailable")
		}

		for _, sh := range shipments {
			if sh.fixSnapshot {
				if errExec := tx.Exec(
					"UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=NOW(3) WHERE id=? AND user_id=?",
					sh.price, sh.valueSource, sh.inv.ID, userID,
				).Error; errExec != nil {
					return errExec
				}
			}
			if errExec := tx.Exec(
				"INSERT INTO shipping_records (user_id, order_id, order_item_id, inventory_id, product_id, quantity, price, address_id, status, batch_no, remark) VALUES (?,?,?,?,?,?,?,?,?,?,?)",
				userID, sh.inv.OrderID, 0, sh.inv.ID, sh.inv.ProductID, 1, sh.price, addrID, 1, batchNo, "batch_request_shipping",
			).Error; errExec != nil {
				return errExec
			}
		}
		return nil
	})
	if txErr != nil {
		// The transaction rolled back: nothing was shipped.
		failAll(validIDs, txErr.Error())
		return addrID, batchNo, success, skipped, failed, nil
	}

	success = validIDs
	return addrID, batchNo, success, skipped, failed, nil
}
// RedeemInventoryToPoints converts one inventory item into points.
//
// The item must belong to userID and be in status 1, and must not have been
// received through a transfer (a user_inventory_transfers row with this user
// as recipient blocks redemption). Points are value_cents multiplied by the
// "points_exchange_per_cent" system config (default 1). When the item has no
// value snapshot, the current product price is used and written back.
//
// To prevent concurrent requests from crediting points twice, the item is
// claimed first with a status-guarded UPDATE (1 -> 3); points are only added
// when that UPDATE changed a row. If adding points fails, the claim is undone.
//
// Returns the points credited. Errors: "transfer_inventory_cannot_redeem",
// "inventory_unavailable" (item changed concurrently), or any lookup,
// database or AddPoints error.
func (s *service) RedeemInventoryToPoints(ctx context.Context, userID int64, inventoryID int64) (int64, error) {
	inv, err := s.readDB.UserInventory.WithContext(ctx).Where(
		s.readDB.UserInventory.UserID.Eq(userID),
		s.readDB.UserInventory.ID.Eq(inventoryID),
		s.readDB.UserInventory.Status.Eq(1),
	).First()
	if err != nil {
		return 0, err
	}

	// Transfer-origin check. A failed lookup must not let the redemption
	// through, so the error is returned instead of being ignored.
	transferCnt, err := s.readDB.UserInventoryTransfers.WithContext(ctx).Where(
		s.readDB.UserInventoryTransfers.InventoryID.Eq(inventoryID),
		s.readDB.UserInventoryTransfers.ToUserID.Eq(userID),
	).Count()
	if err != nil {
		return 0, err
	}
	if transferCnt > 0 {
		return 0, fmt.Errorf("transfer_inventory_cannot_redeem")
	}

	// Resolve the item's value, persisting a snapshot when there was none.
	valueCents := inv.ValueCents
	if valueCents <= 0 {
		p, err := s.readDB.Products.WithContext(ctx).Where(s.readDB.Products.ID.Eq(inv.ProductID)).First()
		if err != nil {
			return 0, err
		}
		valueCents = p.Price
		if db := s.repo.GetDbW().Exec("UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=? WHERE id=? AND user_id=?", valueCents, 2, time.Now(), inventoryID, userID); db.Error != nil {
			return 0, db.Error
		}
	}

	// Exchange rate: best-effort lookup, falling back to 1 when the config is
	// missing, unreadable or not a positive integer.
	rate := int64(1)
	cfg, _ := s.readDB.SystemConfigs.WithContext(ctx).Where(s.readDB.SystemConfigs.ConfigKey.Eq("points_exchange_per_cent")).First()
	if cfg != nil {
		var r int64
		_, _ = fmt.Sscanf(cfg.ConfigValue, "%d", &r)
		if r > 0 {
			rate = r
		}
	}
	points := valueCents * rate

	// Claim the item before crediting anything; only one request can win.
	claim := s.repo.GetDbW().Exec("UPDATE user_inventory SET status=3, remark=CONCAT(IFNULL(remark,''),'|redeemed_points=',?) WHERE id=? AND user_id=? AND status=1", points, inventoryID, userID)
	if claim.Error != nil {
		return 0, claim.Error
	}
	if claim.RowsAffected == 0 {
		return 0, fmt.Errorf("inventory_unavailable")
	}

	if err = s.AddPoints(ctx, userID, points, "redeem_reward", fmt.Sprintf("inventory:%d product:%d", inventoryID, inv.ProductID), nil, nil); err != nil {
		// Undo the claim so the user keeps the item when no points were added.
		revert := s.repo.GetDbW().Exec("UPDATE user_inventory SET status=1, remark=CONCAT(IFNULL(remark,''),'|redeem_reverted') WHERE id=? AND user_id=? AND status=3", inventoryID, userID)
		if revert.Error != nil {
			s.logger.Error("RedeemInventoryToPoints: Revert claim failed", zap.Int64("invID", inventoryID), zap.Int64("userID", userID), zap.Error(revert.Error))
		}
		return 0, err
	}
	return points, nil
}
// RedeemInventoriesToPoints converts several inventory items into points in
// one transaction and returns the total points credited.
//
// inventoryIDs are de-duplicated and non-positive IDs dropped. Only items that
// belong to userID, are in status 1 and were not received through a transfer
// are considered. Each item is worth value_cents (or, without a snapshot, the
// current product price, which is written back) multiplied by the
// "points_exchange_per_cent" system config (default 1). Items whose value
// cannot be determined or is not positive are left untouched.
//
// Inside the transaction the items are claimed with a single status-guarded
// UPDATE; unless every item flips 1 -> 3 the transaction is rolled back, so
// points are never credited for an item that was redeemed or shipped
// concurrently.
//
// Errors: "invalid_params", "no_valid_inventory",
// "transfer_inventory_cannot_redeem", "no_valid_products",
// "inventory_unavailable", or any lookup/database error.
func (s *service) RedeemInventoriesToPoints(ctx context.Context, userID int64, inventoryIDs []int64) (int64, error) {
	// 1. De-duplicate, dropping non-positive IDs.
	seen := make(map[int64]struct{}, len(inventoryIDs))
	uniq := make([]int64, 0, len(inventoryIDs))
	for _, id := range inventoryIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		uniq = append(uniq, id)
	}
	if len(uniq) == 0 {
		return 0, fmt.Errorf("invalid_params")
	}

	// 2. Exchange rate, looked up once. Best-effort: falls back to 1 when the
	// config is missing, unreadable or not a positive integer.
	rate := int64(1)
	cfg, _ := s.readDB.SystemConfigs.WithContext(ctx).Where(s.readDB.SystemConfigs.ConfigKey.Eq("points_exchange_per_cent")).First()
	if cfg != nil {
		var r int64
		_, _ = fmt.Sscanf(cfg.ConfigValue, "%d", &r)
		if r > 0 {
			rate = r
		}
	}

	// 3. Load the user's redeemable items in one query.
	invList, err := s.readDB.UserInventory.WithContext(ctx).
		Where(s.readDB.UserInventory.ID.In(uniq...)).
		Where(s.readDB.UserInventory.UserID.Eq(userID)).
		Where(s.readDB.UserInventory.Status.Eq(1)).
		Find()
	if err != nil {
		return 0, err
	}
	if len(invList) == 0 {
		return 0, fmt.Errorf("no_valid_inventory")
	}

	// 3.5 Drop items received through a transfer. A failed lookup must not
	// let such items through, so the error is returned.
	invIDs := make([]int64, 0, len(invList))
	for _, inv := range invList {
		invIDs = append(invIDs, inv.ID)
	}
	transfers, err := s.readDB.UserInventoryTransfers.WithContext(ctx).
		Where(s.readDB.UserInventoryTransfers.InventoryID.In(invIDs...)).
		Where(s.readDB.UserInventoryTransfers.ToUserID.Eq(userID)).
		Find()
	if err != nil {
		return 0, err
	}
	transferredSet := make(map[int64]struct{}, len(transfers))
	for _, t := range transfers {
		transferredSet[t.InventoryID] = struct{}{}
	}
	redeemable := make([]*model.UserInventory, 0, len(invList))
	for _, inv := range invList {
		if _, transferred := transferredSet[inv.ID]; !transferred {
			redeemable = append(redeemable, inv)
		}
	}
	if len(redeemable) == 0 {
		return 0, fmt.Errorf("transfer_inventory_cannot_redeem")
	}

	// 4. Fetch current prices for items that have no value snapshot.
	productIDs := make([]int64, 0, len(redeemable))
	productIDSet := make(map[int64]struct{}, len(redeemable))
	for _, inv := range redeemable {
		if inv.ValueCents > 0 {
			continue
		}
		if _, ok := productIDSet[inv.ProductID]; !ok {
			productIDSet[inv.ProductID] = struct{}{}
			productIDs = append(productIDs, inv.ProductID)
		}
	}
	productPriceMap := make(map[int64]int64, len(productIDs))
	if len(productIDs) > 0 {
		products, err := s.readDB.Products.WithContext(ctx).
			Where(s.readDB.Products.ID.In(productIDs...)).
			Find()
		if err != nil {
			return 0, err
		}
		for _, p := range products {
			productPriceMap[p.ID] = p.Price
		}
	}

	// 5. Total the points and collect the snapshot write-backs.
	type valueFix struct {
		ID          int64
		ValueCents  int64
		ValueSnapAt time.Time
	}
	var totalPoints int64
	validIDs := make([]int64, 0, len(redeemable))
	valueFixes := make([]valueFix, 0)
	for _, inv := range redeemable {
		valueCents := inv.ValueCents
		if valueCents <= 0 {
			price, ok := productPriceMap[inv.ProductID]
			if !ok {
				continue // product not found: item cannot be valued
			}
			valueCents = price
			valueFixes = append(valueFixes, valueFix{
				ID:          inv.ID,
				ValueCents:  valueCents,
				ValueSnapAt: time.Now(),
			})
		}
		if valueCents <= 0 {
			continue
		}
		totalPoints += valueCents * rate
		validIDs = append(validIDs, inv.ID)
	}
	if len(validIDs) == 0 {
		return 0, fmt.Errorf("no_valid_products")
	}

	// 6. One transaction: claim the items, persist snapshots, credit points.
	err = s.repo.GetDbW().Transaction(func(tx *gorm.DB) error {
		// Claim first. Every item must flip 1 -> 3; otherwise something was
		// redeemed or shipped concurrently and nothing may be credited.
		claim := tx.Exec(
			"UPDATE user_inventory SET status=3, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|batch_redeemed') WHERE id IN ? AND user_id=? AND status=1",
			validIDs, userID,
		)
		if claim.Error != nil {
			return claim.Error
		}
		if claim.RowsAffected != int64(len(validIDs)) {
			return fmt.Errorf("inventory_unavailable")
		}

		// Persist value snapshots (source 2 = taken from the product price).
		for _, fix := range valueFixes {
			if err := tx.Exec(
				"UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=? WHERE id=? AND user_id=?",
				fix.ValueCents, 2, fix.ValueSnapAt, fix.ID, userID,
			).Error; err != nil {
				return err
			}
		}

		// Ledger entry for the whole batch.
		ledger := &model.UserPointsLedger{
			UserID:   userID,
			Action:   "redeem_reward",
			Points:   totalPoints,
			RefTable: "user_inventory",
			RefID:    fmt.Sprintf("batch:%d", len(validIDs)),
			Remark:   fmt.Sprintf("batch_redeem_%d_items", len(validIDs)),
		}
		if err := tx.Create(ledger).Error; err != nil {
			return err
		}

		// Points balance row, valid for 100 years from now.
		now := time.Now()
		pointRecord := &model.UserPoints{
			UserID:     userID,
			Kind:       "redeem_reward",
			Points:     totalPoints,
			ValidStart: now,
			ValidEnd:   now.AddDate(100, 0, 0),
		}
		return tx.Create(pointRecord).Error
	})
	if err != nil {
		return 0, err
	}
	return totalPoints, nil
}
// VoidUserInventory marks one of userID's inventory items as voided
// (status 1 -> 2) and appends "|void_by_admin" to its remark.
//
// adminID is accepted for interface compatibility but is not recorded.
//
// Errors: "invalid_params" for non-positive IDs, "invalid_status" when the
// item is not in status 1 — including when it left status 1 between the
// lookup and the update, so a void that changed nothing is never reported as
// success — or any lookup/database error.
func (s *service) VoidUserInventory(ctx context.Context, adminID int64, userID int64, inventoryID int64) error {
	_ = adminID // currently unused
	if userID <= 0 || inventoryID <= 0 {
		return fmt.Errorf("invalid_params")
	}
	inv, err := s.readDB.UserInventory.WithContext(ctx).
		Where(s.readDB.UserInventory.ID.Eq(inventoryID)).
		Where(s.readDB.UserInventory.UserID.Eq(userID)).
		First()
	if err != nil {
		return err
	}
	if inv.Status != 1 {
		return fmt.Errorf("invalid_status")
	}
	res := s.repo.GetDbW().Exec("UPDATE user_inventory SET status=2, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|void_by_admin') WHERE id=? AND user_id=? AND status=1", inventoryID, userID)
	if res.Error != nil {
		return res.Error
	}
	if res.RowsAffected == 0 {
		// The status-guarded UPDATE matched nothing: the item changed state
		// after the read above.
		return fmt.Errorf("invalid_status")
	}
	return nil
}