bindbox-game/internal/service/user/address_share.go
win 8229b41382 fix(security): 修复赠送资产薅积分三大漏洞
1. SELECT FOR UPDATE 锁定资产行,防止并发转赠竞态条件
2. 检查 RowsAffected 防止 GORM 静默失败导致空壳发货记录
3. 兑换积分时校验转赠来源,禁止转赠资产兑换积分
4. 转赠来源校验改用写库查询,避免主从延迟绕过
5. 转赠来源查询错误不再静默忽略,失败时返回错误

基于 zuncle 分支修复,额外修正了两个安全隐患:
- RedeemInventoryToPoints/RedeemInventoriesToPoints 中
  转赠记录查询从 readDB 改为 writeDB
- Count()/Find() 返回的 error 不再丢弃
2026-03-11 16:25:11 +08:00

826 lines
27 KiB
Go
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package user
import (
"context"
"fmt"
"time"
"bindbox-game/configs"
"bindbox-game/internal/repository/mysql/model"
"bindbox-game/internal/pkg/wechat"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"gorm.io/gorm"
)
// shareClaims is the JWT payload of an address-share token: which user shared
// which inventory item, plus the standard registered claims (nbf/iat/exp are
// filled in by signShareToken). Field order determines the encoded JSON.
type shareClaims struct {
	// OwnerUserID is the user that created the share (the inventory owner).
	OwnerUserID int64 `json:"owner_user_id"`
	// InventoryID is the shared user_inventory row.
	InventoryID int64 `json:"inventory_id"`
	jwt.RegisteredClaims
}
// signShareToken returns an HS256-signed share token for inventoryID owned by
// ownerUserID. The token is valid from the time of signing until expiresAt and
// is signed with the configured Random.CommitMasterKey.
func signShareToken(ownerUserID int64, inventoryID int64, expiresAt time.Time) (string, error) {
	registered := jwt.RegisteredClaims{
		NotBefore: jwt.NewNumericDate(time.Now()),
		IssuedAt:  jwt.NewNumericDate(time.Now()),
		ExpiresAt: jwt.NewNumericDate(expiresAt),
	}
	payload := shareClaims{
		OwnerUserID:      ownerUserID,
		InventoryID:      inventoryID,
		RegisteredClaims: registered,
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	secret := []byte(configs.Get().Random.CommitMasterKey)
	return token.SignedString(secret)
}
// parseShareToken verifies tokenString as a share token produced by
// signShareToken and returns its claims.
//
// Only HS256 is accepted (jwt.WithValidMethods), so a token cannot select a
// different algorithm through its header. Expiry / not-before are validated by
// the jwt parser. The function never returns (nil, nil): every failure comes
// back as a non-nil error, so callers may safely dereference the claims when
// err == nil.
func parseShareToken(tokenString string) (*shareClaims, error) {
	claims := &shareClaims{}
	token, err := jwt.ParseWithClaims(
		tokenString,
		claims,
		func(*jwt.Token) (interface{}, error) {
			return []byte(configs.Get().Random.CommitMasterKey), nil
		},
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}
	// Defensive: the original code could fall through here with err == nil and
	// return (nil, nil), which made callers dereference a nil pointer.
	if token == nil || !token.Valid {
		return nil, fmt.Errorf("invalid_or_expired_token")
	}
	return claims, nil
}
func (s *service) CreateAddressShare(ctx context.Context, userID int64, inventoryID int64, expiresAt time.Time) (string, string, time.Time, error) {
inv, err := s.readDB.UserInventory.WithContext(ctx).Where(s.readDB.UserInventory.UserID.Eq(userID), s.readDB.UserInventory.ID.Eq(inventoryID), s.readDB.UserInventory.Status.Eq(1)).First()
if err != nil {
return "", "", time.Time{}, err
}
token, err := signShareToken(userID, inventoryID, expiresAt)
if err != nil {
return "", "", time.Time{}, err
}
// 尝试生成微信小程序 ShortLink
shortLink := ""
c := configs.Get()
if c.Wechat.AppID != "" && c.Wechat.AppSecret != "" {
wcfg := &wechat.WechatConfig{AppID: c.Wechat.AppID, AppSecret: c.Wechat.AppSecret}
at, errat := wechat.GetAccessTokenWithContext(ctx, wcfg)
if errat == nil {
// BUG修复地址填写页在 pages-user 分包下,需添加 pages-user 前缀
pagePath := fmt.Sprintf("pages-user/address/submit?token=%s", token)
pageTitle := "送你一个好礼,快来填写地址领走吧!"
if inv.Remark != "" {
pageTitle = fmt.Sprintf("送你一个%s快来领走吧", inv.Remark)
}
sl, errsl := wechat.GetShortLink(at, pagePath, pageTitle)
if errsl == nil {
shortLink = sl
s.logger.Info("成功生成微信短链", zap.String("short_link", shortLink))
} else {
// 降级尝试生成 Scheme
s.logger.Warn("生成微信短链失败尝试降级为Scheme", zap.Error(errsl), zap.String("page_path", pagePath))
// 修正 pagePath 格式URL Scheme 需要 path 和 query 分离
// BUG修复地址填写页在 pages-user 分包下
schemePath := "pages-user/address/submit"
schemeQuery := fmt.Sprintf("token=%s", token)
scheme, errScheme := wechat.GenerateScheme(at, schemePath, schemeQuery, "release")
if errScheme == nil {
shortLink = scheme
s.logger.Info("成功生成微信Scheme", zap.String("scheme", scheme))
} else {
s.logger.Error("生成微信Scheme也失败", zap.Error(errScheme))
}
}
} else {
s.logger.Error("获取微信AccessToken失败", zap.Error(errat))
}
} else {
s.logger.Warn("微信配置缺失,跳过短链生成", zap.String("appid", c.Wechat.AppID))
}
return token, shortLink, expiresAt, nil
}
// RevokeAddressShare is currently a no-op that always returns nil. Share tokens
// are self-contained signed JWTs (see signShareToken) and no server-side share
// state is kept here, so nothing in this method invalidates an issued token.
func (s *service) RevokeAddressShare(ctx context.Context, userID int64, inventoryID int64) error {
	_, _, _ = ctx, userID, inventoryID
	return nil
}
// SubmitAddressShare redeems a share token produced by CreateAddressShare. The
// recipient submits a shipping address, and in one write-DB transaction the
// shared inventory row is locked and checked, optionally transferred, marked as
// shipping-requested (status 3), and a shipping record is created.
//
// If submittedByUserID is a positive ID different from the token owner, the
// item is transferred to that user (a user_inventory_transfers row is written
// and both the address and the shipping record belong to them). Otherwise the
// owner ships the item to the submitted address. submittedIP is optional and
// may be nil. Returns the ID of the newly created address.
//
// Errors: "invalid_or_expired_token", "inventory_unavailable",
// "already_processed", or the underlying DB error.
func (s *service) SubmitAddressShare(ctx context.Context, shareToken string, name string, mobile string, province string, city string, district string, address string, submittedByUserID *int64, submittedIP *string) (int64, error) {
	claims, err := parseShareToken(shareToken)
	if err != nil || claims == nil {
		// Log only a prefix of the (untrusted) token. Slicing unconditionally
		// would panic on tokens shorter than 10 bytes.
		masked := shareToken
		if len(masked) > 10 {
			masked = masked[:10]
		}
		s.logger.Error("SubmitAddressShare: Token parse failed", zap.Error(err), zap.String("token_masked", masked+"..."))
		return 0, fmt.Errorf("invalid_or_expired_token")
	}
	s.logger.Info("SubmitAddressShare: Processing", zap.Int64("invID", claims.InventoryID), zap.Int64("owner", claims.OwnerUserID))

	// 1. Decide who ends up owning the item (named transfer when a different
	//    logged-in user submits the form).
	targetUserID := claims.OwnerUserID
	isTransfer := false
	if submittedByUserID != nil && *submittedByUserID > 0 && *submittedByUserID != claims.OwnerUserID {
		targetUserID = *submittedByUserID
		isTransfer = true
	}
	// submittedIP is optional; dereferencing it blindly would panic.
	ip := ""
	if submittedIP != nil {
		ip = *submittedIP
	}

	var addrID int64
	err = s.repo.GetDbW().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// a. Lock the inventory row (SELECT ... FOR UPDATE) so concurrent
		//    submissions of the same token are serialized.
		var inv model.UserInventory
		lockResult := tx.Raw("SELECT * FROM user_inventory WHERE id = ? FOR UPDATE", claims.InventoryID).Scan(&inv)
		if lockResult.Error != nil {
			s.logger.Error("SubmitAddressShare: Lock inventory failed", zap.Int64("invID", claims.InventoryID), zap.Error(lockResult.Error))
			return lockResult.Error
		}
		if inv.ID == 0 {
			s.logger.Warn("SubmitAddressShare: Inventory not found", zap.Int64("invID", claims.InventoryID))
			return fmt.Errorf("inventory_unavailable")
		}
		if inv.Status != 1 {
			s.logger.Warn("SubmitAddressShare: Inventory unavailable", zap.Int64("invID", claims.InventoryID), zap.Int32("status", inv.Status))
			return fmt.Errorf("inventory_unavailable")
		}

		// b. Existing (non-cancelled, status != 5) shipping records, checked on
		//    the write connection to avoid replica lag.
		var shipCnt int64
		if err := tx.Raw("SELECT COUNT(*) FROM shipping_records WHERE inventory_id = ? AND status != 5", claims.InventoryID).Scan(&shipCnt).Error; err != nil {
			return err
		}
		if shipCnt > 0 {
			s.logger.Warn("SubmitAddressShare: Already processed", zap.Int64("invID", claims.InventoryID))
			return fmt.Errorf("already_processed")
		}

		// c. Create the address for targetUserID. It becomes the default when
		//    the user has none yet. Checked inside the tx (not on the replica,
		//    and without ignoring errors) so the decision is consistent.
		var defaultCnt int64
		if err := tx.Model(&model.UserAddresses{}).Where("user_id = ? AND is_default = 1", targetUserID).Count(&defaultCnt).Error; err != nil {
			return err
		}
		arow := &model.UserAddresses{
			UserID:    targetUserID,
			Name:      name,
			Mobile:    mobile,
			Province:  province,
			City:      city,
			District:  district,
			Address:   address,
			IsDefault: 0,
		}
		if defaultCnt == 0 {
			arow.IsDefault = 1
		}
		if err := tx.Omit("DefaultUserUnique").Create(arow).Error; err != nil {
			return err
		}
		addrID = arow.ID

		// d. Update status (and ownership on transfer). The RowsAffected check
		//    guarantees we never leave a shipping record for an item that was
		//    not actually claimed.
		var result *gorm.DB
		if isTransfer {
			transferLog := &model.UserInventoryTransfers{
				InventoryID: claims.InventoryID,
				FromUserID:  claims.OwnerUserID,
				ToUserID:    targetUserID,
				Remark:      "address_share_transfer",
			}
			if err := tx.Create(transferLog).Error; err != nil {
				return err
			}
			result = tx.Table("user_inventory").Where("id = ? AND user_id = ? AND status = 1", claims.InventoryID, claims.OwnerUserID).
				Updates(map[string]interface{}{
					"user_id":    targetUserID,
					"status":     3,
					"updated_at": time.Now(),
					"remark":     fmt.Sprintf("transferred_from_%d|shipping_requested", claims.OwnerUserID),
				})
		} else {
			result = tx.Table("user_inventory").Where("id = ? AND user_id = ? AND status = 1", claims.InventoryID, claims.OwnerUserID).
				Updates(map[string]interface{}{
					"status":     3,
					"updated_at": time.Now(),
					"remark":     "shipping_requested_via_share",
				})
		}
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return fmt.Errorf("inventory_unavailable")
		}

		// e. Shipping record for targetUserID, priced from the value snapshot so
		//    it matches the decomposition price. Without a snapshot, fall back to
		//    the product's current price and persist it as the snapshot.
		price := inv.ValueCents
		if price <= 0 && inv.ProductID > 0 {
			var p model.Products
			if err := tx.Table("products").Where("id = ?", inv.ProductID).First(&p).Error; err == nil {
				price = p.Price
				if err := tx.Table("user_inventory").Where("id = ?", inv.ID).Updates(map[string]interface{}{
					"value_cents":       price,
					"value_source":      2,
					"value_snapshot_at": time.Now(),
				}).Error; err != nil {
					return err
				}
			}
		}
		// "T" prefix distinguishes share/transfer batches from regular "B" batches.
		transferBatchNo := fmt.Sprintf("T%d%d", targetUserID, time.Now().UnixNano()/1000000)
		shipRecord := &model.ShippingRecords{
			UserID:      targetUserID,
			OrderID:     inv.OrderID,
			OrderItemID: 0,
			InventoryID: claims.InventoryID,
			ProductID:   inv.ProductID,
			Quantity:    1,
			Price:       price,
			AddressID:   addrID,
			Status:      1,
			BatchNo:     transferBatchNo,
			Remark:      fmt.Sprintf("shared_address_submit|ip=%s|transfer_from=%d", ip, claims.OwnerUserID),
		}
		return tx.Omit("ShippedAt", "ReceivedAt").Create(shipRecord).Error
	})
	if err != nil {
		return 0, err
	}
	return addrID, nil
}
func (s *service) RequestShipping(ctx context.Context, userID int64, inventoryID int64) (int64, error) {
return s.RequestShippingWithBatch(ctx, userID, inventoryID, "", 0)
}
// RequestShippingWithBatch 申请发货(支持批次号和指定地址)
func (s *service) RequestShippingWithBatch(ctx context.Context, userID int64, inventoryID int64, batchNo string, addrID int64) (int64, error) {
cnt, err := s.readDB.ShippingRecords.WithContext(ctx).Where(
s.readDB.ShippingRecords.InventoryID.Eq(inventoryID),
s.readDB.ShippingRecords.Status.Neq(5), // Ignore cancelled
).Count()
if err == nil && cnt > 0 {
return 0, fmt.Errorf("already_processed")
}
inv, err := s.readDB.UserInventory.WithContext(ctx).Where(s.readDB.UserInventory.UserID.Eq(userID), s.readDB.UserInventory.ID.Eq(inventoryID), s.readDB.UserInventory.Status.Eq(1)).First()
if err != nil {
return 0, err
}
// 如果没有传入地址ID使用默认地址
if addrID <= 0 {
addr, err := s.readDB.UserAddresses.WithContext(ctx).Where(s.readDB.UserAddresses.UserID.Eq(userID), s.readDB.UserAddresses.IsDefault.Eq(1)).First()
if err != nil {
return 0, err
}
addrID = addr.ID
}
// 使用资产价值快照,确保价格与分解时一致
price := inv.ValueCents
if price <= 0 && inv.ProductID > 0 {
// 如果没有快照价格,回退到商品当前价格并记录快照
if p, e := s.readDB.Products.WithContext(ctx).Where(s.readDB.Products.ID.Eq(inv.ProductID)).First(); e == nil && p != nil {
price = p.Price
// 回写资产价值快照
s.repo.GetDbW().Exec("UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=NOW(3) WHERE id=? AND user_id=?",
price, 2, inventoryID, userID)
}
}
if db := s.repo.GetDbW().Exec("INSERT INTO shipping_records (user_id, order_id, order_item_id, inventory_id, product_id, quantity, price, address_id, status, batch_no, remark) VALUES (?,?,?,?,?,?,?,?,?,?,?)", userID, inv.OrderID, 0, inventoryID, inv.ProductID, 1, price, addrID, 1, batchNo, "user_request_shipping"); db.Error != nil {
err = db.Error
return 0, err
}
if db := s.repo.GetDbW().Exec("UPDATE user_inventory SET status=3, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|shipping_requested') WHERE id=? AND user_id=? AND status=1", inventoryID, userID); db.Error != nil {
err = db.Error
return 0, err
}
return addrID, nil
}
// generateBatchNo builds a shipping batch number: "B", the user ID, then the
// current Unix time in milliseconds (e.g. "B421700000000000").
func generateBatchNo(userID int64) string {
	millis := time.Now().UnixMilli()
	return fmt.Sprintf("B%d%d", userID, millis)
}
// RequestShippings requests shipment of several inventory items in one batch
// (a single generated batch number and one address).
//
// Address: addressID when given (it must belong to userID). Otherwise the
// user's default address is used. If the user has no default but exactly one
// address, that address is made the default.
//
// Parameter problems are reported as a single failed entry with ID 0 and a
// reason of "invalid_params", "address_not_found" or "no_default_address". In
// every case err is nil and all outcomes are reported per item:
//   - success: item IDs for which a shipping record was created;
//   - skipped: "not_found", "not_owned", "already_processed",
//     "already_requested", "invalid_status";
//   - failed:  item IDs whose processing hit a DB error (reason = error text).
//
// The pre-classification reads are advisory only. Inside the write transaction
// the candidate rows are locked (SELECT ... FOR UPDATE) and re-checked, so
// concurrent requests cannot create duplicate or orphan shipping records.
func (s *service) RequestShippings(ctx context.Context, userID int64, inventoryIDs []int64, addressID *int64) (addrID int64, batchNo string, success []int64, skipped []struct {
	ID     int64
	Reason string
}, failed []struct {
	ID     int64
	Reason string
}, err error) {
	// item is an alias (not a new type) of the anonymous result struct, so
	// []item is assignable to the named results above.
	type item = struct {
		ID     int64
		Reason string
	}
	paramFail := func(reason string) []item { return []item{{ID: 0, Reason: reason}} }
	failAll := func(ids []int64, reason string) []item {
		out := make([]item, 0, len(ids))
		for _, id := range ids {
			out = append(out, item{ID: id, Reason: reason})
		}
		return out
	}

	// 1. Deduplicate, dropping non-positive IDs.
	seen := make(map[int64]struct{}, len(inventoryIDs))
	uniq := make([]int64, 0, len(inventoryIDs))
	for _, id := range inventoryIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		uniq = append(uniq, id)
	}
	if len(uniq) == 0 {
		return 0, "", nil, nil, paramFail("invalid_params"), nil
	}

	// 2. Resolve the shipping address.
	if addressID != nil && *addressID > 0 {
		ua, errAddr := s.readDB.UserAddresses.WithContext(ctx).Where(s.readDB.UserAddresses.ID.Eq(*addressID), s.readDB.UserAddresses.UserID.Eq(userID)).First()
		if errAddr != nil || ua == nil {
			return 0, "", nil, nil, paramFail("address_not_found"), nil
		}
		addrID = ua.ID
	} else {
		da, errDef := s.readDB.UserAddresses.WithContext(ctx).Where(s.readDB.UserAddresses.UserID.Eq(userID), s.readDB.UserAddresses.IsDefault.Eq(1)).First()
		if errDef == nil && da != nil {
			addrID = da.ID
		} else {
			addrs, errFind := s.readDB.UserAddresses.WithContext(ctx).Where(s.readDB.UserAddresses.UserID.Eq(userID)).Find()
			if errFind != nil || len(addrs) != 1 {
				return 0, "", nil, nil, paramFail("no_default_address"), nil
			}
			// Exactly one address: promote it to default. This is a write, so it
			// must go through the write connection, not readDB.
			target := addrs[0]
			if errUpd := s.repo.GetDbW().WithContext(ctx).Model(&model.UserAddresses{}).
				Where("id = ? AND user_id = ?", target.ID, userID).
				Update("is_default", 1).Error; errUpd != nil {
				s.logger.Error("Auto set default address failed", zap.Error(errUpd))
				return 0, "", nil, nil, paramFail("no_default_address"), nil
			}
			addrID = target.ID
		}
	}

	// 3. Batch number shared by every record created below.
	batchNo = generateBatchNo(userID)

	success = make([]int64, 0, len(uniq))
	skipped = make([]item, 0)
	failed = make([]item, 0)

	// 4. Prefetch the inventory rows in one query (advisory pre-filter).
	invList, errInv := s.readDB.UserInventory.WithContext(ctx).
		Where(s.readDB.UserInventory.ID.In(uniq...)).
		Find()
	if errInv != nil {
		return addrID, batchNo, success, skipped, failAll(uniq, errInv.Error()), nil
	}
	invMap := make(map[int64]*model.UserInventory, len(invList))
	for _, inv := range invList {
		invMap[inv.ID] = inv
	}

	// 5. Items that already have a non-cancelled (status != 5) shipping record.
	var shippedIDs []int64
	if errShip := s.repo.GetDbR().WithContext(ctx).Raw("SELECT DISTINCT inventory_id FROM shipping_records WHERE inventory_id IN ? AND status != 5", uniq).Scan(&shippedIDs).Error; errShip != nil {
		return addrID, batchNo, success, skipped, failAll(uniq, errShip.Error()), nil
	}
	shipped := make(map[int64]struct{}, len(shippedIDs))
	for _, id := range shippedIDs {
		shipped[id] = struct{}{}
	}

	// 6. Classify.
	validInvs := make([]*model.UserInventory, 0, len(uniq))
	candidateIDs := make([]int64, 0, len(uniq))
	for _, id := range uniq {
		inv := invMap[id]
		_, alreadyShipped := shipped[id]
		switch {
		case inv == nil:
			skipped = append(skipped, item{ID: id, Reason: "not_found"})
		case inv.UserID != userID:
			skipped = append(skipped, item{ID: id, Reason: "not_owned"})
		case alreadyShipped:
			skipped = append(skipped, item{ID: id, Reason: "already_processed"})
		case inv.Status == 3:
			skipped = append(skipped, item{ID: id, Reason: "already_requested"})
		case inv.Status != 1:
			skipped = append(skipped, item{ID: id, Reason: "invalid_status"})
		default:
			validInvs = append(validInvs, inv)
			candidateIDs = append(candidateIDs, inv.ID)
		}
	}
	if len(validInvs) == 0 {
		return addrID, batchNo, success, skipped, failed, nil
	}

	// 7. Current product prices, needed only for items without a value snapshot.
	productPrice := make(map[int64]int64)
	needPrice := make([]int64, 0)
	seenProduct := make(map[int64]struct{})
	for _, inv := range validInvs {
		if inv.ValueCents > 0 || inv.ProductID <= 0 {
			continue
		}
		if _, ok := seenProduct[inv.ProductID]; ok {
			continue
		}
		seenProduct[inv.ProductID] = struct{}{}
		needPrice = append(needPrice, inv.ProductID)
	}
	if len(needPrice) > 0 {
		prods, errProd := s.readDB.Products.WithContext(ctx).Where(s.readDB.Products.ID.In(needPrice...)).Find()
		if errProd != nil {
			return addrID, batchNo, success, skipped, failAll(candidateIDs, errProd.Error()), nil
		}
		for _, p := range prods {
			productPrice[p.ID] = p.Price
		}
	}

	// 8. One transaction: lock, re-check, insert shipping records, update status.
	var shippedNow []int64
	var lostRace []item
	txErr := s.repo.GetDbW().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		shippedNow, lostRace = shippedNow[:0], lostRace[:0]

		// Lock the candidates that are still available, then drop any that
		// gained a shipping record since the advisory read.
		var lockedIDs []int64
		if e := tx.Raw("SELECT id FROM user_inventory WHERE id IN ? AND user_id = ? AND status = 1 FOR UPDATE", candidateIDs, userID).Scan(&lockedIDs).Error; e != nil {
			return e
		}
		var liveShipped []int64
		if e := tx.Raw("SELECT DISTINCT inventory_id FROM shipping_records WHERE inventory_id IN ? AND status != 5", candidateIDs).Scan(&liveShipped).Error; e != nil {
			return e
		}
		available := make(map[int64]struct{}, len(lockedIDs))
		for _, id := range lockedIDs {
			available[id] = struct{}{}
		}
		for _, id := range liveShipped {
			delete(available, id)
		}

		for _, inv := range validInvs {
			if _, ok := available[inv.ID]; !ok {
				lostRace = append(lostRace, item{ID: inv.ID, Reason: "already_processed"})
				continue
			}
			// Price from the value snapshot so it matches the decomposition
			// price. Otherwise use the current product price and persist it as
			// the snapshot (value_source 2).
			price := inv.ValueCents
			if price <= 0 && inv.ProductID > 0 {
				price = productPrice[inv.ProductID]
				if e := tx.Exec(
					"UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=NOW(3) WHERE id=? AND user_id=?",
					price, 2, inv.ID, userID,
				).Error; e != nil {
					return e
				}
			}
			if e := tx.Exec(
				"INSERT INTO shipping_records (user_id, order_id, order_item_id, inventory_id, product_id, quantity, price, address_id, status, batch_no, remark) VALUES (?,?,?,?,?,?,?,?,?,?,?)",
				userID, inv.OrderID, 0, inv.ID, inv.ProductID, 1, price, addrID, 1, batchNo, "batch_request_shipping",
			).Error; e != nil {
				return e
			}
			shippedNow = append(shippedNow, inv.ID)
		}
		if len(shippedNow) == 0 {
			return nil
		}

		res := tx.Exec(
			"UPDATE user_inventory SET status=3, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|batch_shipping_requested') WHERE id IN ? AND user_id=? AND status=1",
			shippedNow, userID,
		)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected != int64(len(shippedNow)) {
			// Never keep shipping records for rows that were not claimed.
			return fmt.Errorf("inventory_unavailable")
		}
		return nil
	})
	if txErr != nil {
		// The whole transaction was rolled back, so every candidate failed.
		return addrID, batchNo, success, skipped, failAll(candidateIDs, txErr.Error()), nil
	}

	skipped = append(skipped, lostRace...)
	success = append(success, shippedNow...)
	return addrID, batchNo, success, skipped, failed, nil
}
// RedeemInventoryToPoints converts one available (status 1) inventory item
// owned by userID into points and returns the number of points granted.
//
// Points = the item value in cents × the points_exchange_per_cent system config.
// The item value is its value snapshot or, when there is none, the product's
// current price, which is then written back as the snapshot. The config
// defaults to 1 when missing or not a positive integer. Items received through
// a transfer (a user_inventory_transfers row with to_user_id = userID) are
// rejected with "transfer_inventory_cannot_redeem" to stop point farming.
//
// Concurrency: the inventory row is locked with SELECT ... FOR UPDATE, and it
// is consumed (status 3, RowsAffected checked) in the same write-DB transaction
// in which AddPoints is called, as the last step. Two concurrent requests for
// the same item therefore cannot both be paid, and if AddPoints fails the
// inventory change is rolled back.
// NOTE(review): AddPoints presumably uses its own connection. If the final
// COMMIT failed after AddPoints succeeded, points would be granted while the
// item stays redeemable; confirm whether AddPoints can join this tx instead.
//
// Errors: gorm.ErrRecordNotFound when the item is missing, not owned, not in
// status 1, or its product cannot be found; "transfer_inventory_cannot_redeem";
// "inventory_unavailable"; or any DB/AddPoints error.
func (s *service) RedeemInventoryToPoints(ctx context.Context, userID int64, inventoryID int64) (int64, error) {
	// Exchange rate. Lookup/parse failures deliberately fall back to 1.
	rate := int64(1)
	cfg, _ := s.readDB.SystemConfigs.WithContext(ctx).Where(s.readDB.SystemConfigs.ConfigKey.Eq("points_exchange_per_cent")).First()
	if cfg != nil {
		var r int64
		_, _ = fmt.Sscanf(cfg.ConfigValue, "%d", &r)
		if r > 0 {
			rate = r
		}
	}

	var points int64
	err := s.repo.GetDbW().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var inv model.UserInventory
		if err := tx.Raw("SELECT * FROM user_inventory WHERE id = ? FOR UPDATE", inventoryID).Scan(&inv).Error; err != nil {
			return err
		}
		if inv.ID == 0 || inv.UserID != userID || inv.Status != 1 {
			return gorm.ErrRecordNotFound
		}

		// Items obtained via transfer may not be redeemed (checked on the write
		// connection to avoid replica lag).
		var transferCnt int64
		if err := tx.Raw("SELECT COUNT(*) FROM user_inventory_transfers WHERE inventory_id = ? AND to_user_id = ?", inventoryID, userID).Scan(&transferCnt).Error; err != nil {
			return err
		}
		if transferCnt > 0 {
			return fmt.Errorf("transfer_inventory_cannot_redeem")
		}

		valueCents := inv.ValueCents
		if valueCents <= 0 {
			var p model.Products
			if err := tx.Table("products").Where("id = ?", inv.ProductID).First(&p).Error; err != nil {
				return err
			}
			valueCents = p.Price
			if err := tx.Exec("UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=? WHERE id=? AND user_id=?",
				valueCents, 2, time.Now(), inventoryID, userID).Error; err != nil {
				return err
			}
		}

		points = valueCents * rate

		// Claim the item before granting anything.
		res := tx.Exec("UPDATE user_inventory SET status=3, remark=CONCAT(IFNULL(remark,''),'|redeemed_points=',?) WHERE id=? AND user_id=? AND status=1", points, inventoryID, userID)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return fmt.Errorf("inventory_unavailable")
		}

		return s.AddPoints(ctx, userID, points, "redeem_reward", fmt.Sprintf("inventory:%d product:%d", inventoryID, inv.ProductID), nil, nil)
	})
	if err != nil {
		return 0, err
	}
	return points, nil
}
// RedeemInventoriesToPoints converts several available (status 1) inventory
// items owned by userID into points in one transaction and returns the total
// points granted.
//
// Per item, points = item value in cents × points_exchange_per_cent (default
// 1). The item value is the value snapshot or, when missing, the product's
// current price, which is written back as the snapshot. Items obtained through
// a transfer are excluded. Items whose value cannot be determined or is <= 0
// are silently left out.
//
// The status UPDATE must claim every selected row (RowsAffected is checked).
// Otherwise the whole transaction, including the points ledger rows, is rolled
// back with "inventory_unavailable". A stale or concurrent request can then
// never be paid for items that were already consumed.
//
// Errors: "invalid_params", "no_valid_inventory",
// "transfer_inventory_cannot_redeem", "no_valid_products",
// "inventory_unavailable", or the underlying DB error.
func (s *service) RedeemInventoriesToPoints(ctx context.Context, userID int64, inventoryIDs []int64) (int64, error) {
	if len(inventoryIDs) == 0 {
		return 0, fmt.Errorf("invalid_params")
	}

	// 1. Deduplicate, dropping non-positive IDs.
	dedup := make(map[int64]struct{}, len(inventoryIDs))
	uniq := make([]int64, 0, len(inventoryIDs))
	for _, id := range inventoryIDs {
		if id <= 0 {
			continue
		}
		if _, ok := dedup[id]; !ok {
			dedup[id] = struct{}{}
			uniq = append(uniq, id)
		}
	}
	if len(uniq) == 0 {
		return 0, fmt.Errorf("invalid_params")
	}

	// 2. Exchange rate, queried once. Lookup/parse failures fall back to 1.
	cfg, _ := s.readDB.SystemConfigs.WithContext(ctx).Where(s.readDB.SystemConfigs.ConfigKey.Eq("points_exchange_per_cent")).First()
	rate := int64(1)
	if cfg != nil {
		var r int64
		_, _ = fmt.Sscanf(cfg.ConfigValue, "%d", &r)
		if r > 0 {
			rate = r
		}
	}

	// 3. Load the candidate items in one query.
	invList, err := s.readDB.UserInventory.WithContext(ctx).
		Where(s.readDB.UserInventory.ID.In(uniq...)).
		Where(s.readDB.UserInventory.UserID.Eq(userID)).
		Where(s.readDB.UserInventory.Status.Eq(1)).
		Find()
	if err != nil {
		return 0, err
	}
	if len(invList) == 0 {
		return 0, fmt.Errorf("no_valid_inventory")
	}

	// 3.5 Exclude items obtained via transfer (anti point-farming), checked on
	//     the write connection to avoid replica lag.
	invIDs := make([]int64, 0, len(invList))
	for _, inv := range invList {
		invIDs = append(invIDs, inv.ID)
	}
	var transferredInvs []*model.UserInventoryTransfers
	if err := s.repo.GetDbW().WithContext(ctx).Raw("SELECT * FROM user_inventory_transfers WHERE inventory_id IN ? AND to_user_id = ?", invIDs, userID).Scan(&transferredInvs).Error; err != nil {
		return 0, err
	}
	transferredSet := make(map[int64]struct{}, len(transferredInvs))
	for _, t := range transferredInvs {
		transferredSet[t.InventoryID] = struct{}{}
	}
	filteredInvList := make([]*model.UserInventory, 0, len(invList))
	for _, inv := range invList {
		if _, isTransferred := transferredSet[inv.ID]; !isTransferred {
			filteredInvList = append(filteredInvList, inv)
		}
	}
	if len(filteredInvList) == 0 {
		return 0, fmt.Errorf("transfer_inventory_cannot_redeem")
	}
	invList = filteredInvList

	// 4. Product prices for items without a value snapshot.
	productIDs := make([]int64, 0, len(invList))
	productIDSet := make(map[int64]struct{})
	for _, inv := range invList {
		if inv.ValueCents <= 0 {
			if _, ok := productIDSet[inv.ProductID]; !ok {
				productIDSet[inv.ProductID] = struct{}{}
				productIDs = append(productIDs, inv.ProductID)
			}
		}
	}
	productPriceMap := make(map[int64]int64)
	if len(productIDs) > 0 {
		products, err := s.readDB.Products.WithContext(ctx).
			Where(s.readDB.Products.ID.In(productIDs...)).
			Find()
		if err != nil {
			return 0, err
		}
		for _, p := range products {
			productPriceMap[p.ID] = p.Price
		}
	}

	// 5. Total points and the snapshot write-backs to perform.
	type valueFix struct {
		ID          int64
		ValueCents  int64
		ValueSource int32
		ValueSnapAt time.Time
	}
	var totalPoints int64
	validIDs := make([]int64, 0, len(invList))
	valueFixes := make([]valueFix, 0)
	for _, inv := range invList {
		valueCents := inv.ValueCents
		if valueCents <= 0 {
			price, ok := productPriceMap[inv.ProductID]
			if !ok {
				continue
			}
			valueCents = price
			valueFixes = append(valueFixes, valueFix{
				ID:          inv.ID,
				ValueCents:  valueCents,
				ValueSource: 2,
				ValueSnapAt: time.Now(),
			})
		}
		if valueCents <= 0 {
			continue
		}
		totalPoints += valueCents * rate
		validIDs = append(validIDs, inv.ID)
	}
	if len(validIDs) == 0 {
		return 0, fmt.Errorf("no_valid_products")
	}

	// 6. One transaction: snapshots, claim every item, then grant points.
	err = s.repo.GetDbW().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		for _, fix := range valueFixes {
			if err := tx.Exec(
				"UPDATE user_inventory SET value_cents=?, value_source=?, value_snapshot_at=? WHERE id=? AND user_id=?",
				fix.ValueCents, fix.ValueSource, fix.ValueSnapAt, fix.ID, userID,
			).Error; err != nil {
				return err
			}
		}

		// Claim all items first. Under InnoDB a concurrent claim of the same
		// rows blocks here and then matches fewer rows, so the mismatch check
		// below prevents paying twice.
		res := tx.Exec(
			"UPDATE user_inventory SET status=3, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|batch_redeemed') WHERE id IN ? AND user_id=? AND status=1",
			validIDs, userID,
		)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected != int64(len(validIDs)) {
			return fmt.Errorf("inventory_unavailable")
		}

		now := time.Now()
		ledger := &model.UserPointsLedger{
			UserID:   userID,
			Action:   "redeem_reward",
			Points:   totalPoints,
			RefTable: "user_inventory",
			RefID:    fmt.Sprintf("batch:%d", len(validIDs)),
			Remark:   fmt.Sprintf("batch_redeem_%d_items", len(validIDs)),
		}
		if err := tx.Create(ledger).Error; err != nil {
			return err
		}
		pointRecord := &model.UserPoints{
			UserID:     userID,
			Kind:       "redeem_reward",
			Points:     totalPoints,
			ValidStart: now,
			ValidEnd:   now.AddDate(100, 0, 0), // effectively non-expiring (100 years)
		}
		return tx.Create(pointRecord).Error
	})
	if err != nil {
		return 0, err
	}
	return totalPoints, nil
}
// VoidUserInventory lets an admin void (status 2) an available (status 1)
// inventory item belonging to userID, appending "|void_by_admin" to its remark.
//
// Errors: "invalid_params" for non-positive IDs; the lookup error (e.g.
// gorm.ErrRecordNotFound) when the item does not exist for this user;
// "invalid_status" when the item is not in status 1, including when it changed
// between the read and the write (the UPDATE matched no rows); or the
// underlying DB error.
// NOTE(review): adminID is accepted but not recorded anywhere yet.
func (s *service) VoidUserInventory(ctx context.Context, adminID int64, userID int64, inventoryID int64) error {
	if userID <= 0 || inventoryID <= 0 {
		return fmt.Errorf("invalid_params")
	}
	inv, err := s.readDB.UserInventory.WithContext(ctx).
		Where(s.readDB.UserInventory.ID.Eq(inventoryID)).
		Where(s.readDB.UserInventory.UserID.Eq(userID)).
		First()
	if err != nil {
		return err
	}
	if inv.Status != 1 {
		return fmt.Errorf("invalid_status")
	}
	res := s.repo.GetDbW().WithContext(ctx).Exec("UPDATE user_inventory SET status=2, updated_at=NOW(3), remark=CONCAT(IFNULL(remark,''),'|void_by_admin') WHERE id=? AND user_id=? AND status=1", inventoryID, userID)
	if res.Error != nil {
		return res.Error
	}
	if res.RowsAffected == 0 {
		// The replica read was stale or the item changed concurrently; do not
		// report a void that did not happen.
		return fmt.Errorf("invalid_status")
	}
	_ = adminID
	return nil
}