Merge branch 'main' into fix/openai-ws-passthrough-reasoning-effort

This commit is contained in:
deqiying 2026-05-03 22:18:46 +08:00
commit 11fe29223d
41 changed files with 3226 additions and 243 deletions

View File

@ -2,8 +2,11 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
} }
response.Success(c, result) response.Success(c, result)
} }
// GetUserOverview returns one user's affiliate overview.
// GET /api/v1/admin/affiliates/users/:user_id/overview
func (h *AffiliateHandler) GetUserOverview(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
if err != nil || userID <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, overview)
}
// ListInviteRecords returns all inviter-invitee relationships.
// GET /api/v1/admin/affiliates/invites
func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := parseAffiliateRecordFilter(c, page, pageSize)
items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, total, filter.Page, filter.PageSize)
}
// ListRebateRecords returns all order-level affiliate rebate records.
// GET /api/v1/admin/affiliates/rebates
func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := parseAffiliateRecordFilter(c, page, pageSize)
items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, total, filter.Page, filter.PageSize)
}
// ListTransferRecords returns all affiliate quota-to-balance transfer records.
// GET /api/v1/admin/affiliates/transfers
func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := parseAffiliateRecordFilter(c, page, pageSize)
items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, total, filter.Page, filter.PageSize)
}
func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter {
filter := service.AffiliateRecordFilter{
Search: c.Query("search"),
Page: page,
PageSize: pageSize,
SortBy: c.Query("sort_by"),
SortDesc: c.Query("sort_order") != "asc",
}
if filter.PageSize > 100 {
filter.PageSize = 100
}
userTZ := c.Query("timezone")
if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil {
filter.StartAt = t
}
if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil {
filter.EndAt = t
}
return filter
}
func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
return &parsed
}
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
return &parsed
}
return nil
}
func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
return &parsed
}
if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil {
end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond)
return &end
}
return nil
}

View File

@ -390,7 +390,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
// GetBalanceHistory handles getting user's balance/concurrency change history // GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history // GET /api/v1/admin/users/:id/balance-history
// Query params: // Query params:
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription) // - type: filter by record type (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) { func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64) userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {

View File

@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type) assert.Equal(t, "message_stop", events[1].Type)
} }
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
state := NewResponsesEventToAnthropicState()
state.Model = "gpt-4o"
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
assert.Equal(t, 12, events[0].Usage.InputTokens)
assert.Equal(t, 4, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
state := NewResponsesEventToAnthropicState()
state.Model = "gpt-4o"
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "incomplete",
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
assert.Equal(t, "message_stop", events[1].Type)
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState() state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{

View File

@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
} }
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
},
}, state)
require.Len(t, chunks, 2)
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.done",
Response: &ResponsesResponse{
Status: "incomplete",
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
},
}, state)
require.Len(t, chunks, 2)
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
state := NewResponsesEventToChatState() state := NewResponsesEventToChatState()
state.Model = "gpt-4o" state.Model = "gpt-4o"

View File

@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents(
return resToAnthHandleReasoningDelta(evt, state) return resToAnthHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done": case "response.reasoning_summary_text.done":
return resToAnthHandleBlockDone(state) return resToAnthHandleBlockDone(state)
case "response.completed", "response.incomplete", "response.failed": // response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToAnthHandleCompleted(evt, state) return resToAnthHandleCompleted(evt, state)
default: default:
return nil return nil

View File

@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleReasoningDelta(evt, state) return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done": case "response.reasoning_summary_text.done":
return nil return nil
case "response.completed", "response.incomplete", "response.failed": // response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state) return resToChatHandleCompleted(evt, state)
default: default:
return nil return nil

View File

@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
type ResponsesStreamEvent struct { type ResponsesStreamEvent struct {
Type string `json:"type"` Type string `json:"type"`
// response.created / response.completed / response.failed / response.incomplete // response.created / response.completed / response.done / response.failed / response.incomplete
Response *ResponsesResponse `json:"response,omitempty"` Response *ResponsesResponse `json:"response,omitempty"`
// response.output_item.added / response.output_item.done // response.output_item.added / response.output_item.done

View File

@ -22,6 +22,34 @@ const (
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
const affiliateUserOverviewSQL = `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.aff_code,
COALESCE(ua.aff_rebate_rate_percent, 0)::double precision,
(ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate,
ua.aff_count,
COALESCE(rebated.rebated_invitee_count, 0),
(ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision,
ua.aff_history_quota::double precision
FROM user_affiliates ua
JOIN users u ON u.id = ua.user_id
LEFT JOIN (
SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count
FROM user_affiliate_ledger
WHERE action = 'accrue' AND source_user_id IS NOT NULL
GROUP BY user_id
) rebated ON rebated.user_id = ua.user_id
LEFT JOIN (
SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota
FROM user_affiliate_ledger
WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW()
GROUP BY user_id
) matured ON matured.user_id = ua.user_id
WHERE ua.user_id = $1
LIMIT 1`
type affiliateQueryExecer interface { type affiliateQueryExecer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
@ -86,7 +114,7 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
return bound, nil return bound, nil
} }
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) { func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) {
if amount <= 0 { if amount <= 0 {
return false, nil return false, nil
} }
@ -112,15 +140,15 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
if freezeHours > 0 { if freezeHours > 0 {
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at) INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`, VALUES ($1, 'accrue', $2, $3, $4, NOW() + make_interval(hours => $5), NOW(), NOW())`,
inviterID, amount, inviteeUserID, freezeHours); err != nil { inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID), freezeHours); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err) return fmt.Errorf("insert affiliate accrue ledger: %w", err)
} }
} else { } else {
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, source_order_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { VALUES ($1, 'accrue', $2, $3, $4, NOW(), NOW())`, inviterID, amount, inviteeUserID, nullableInt64Arg(sourceOrderID)); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err) return fmt.Errorf("insert affiliate accrue ledger: %w", err)
} }
} }
@ -275,9 +303,32 @@ FROM cleared`, userID)
return err return err
} }
snapshot, err := queryAffiliateTransferSnapshot(txCtx, txClient, userID)
if err != nil {
return err
}
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) INSERT INTO user_affiliate_ledger (
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil { user_id,
action,
amount,
source_user_id,
balance_after,
aff_quota_after,
aff_frozen_quota_after,
aff_history_quota_after,
created_at,
updated_at
)
VALUES ($1, 'transfer', $2, NULL, $3, $4, $5, $6, NOW(), NOW())`,
userID,
transferred,
snapshot.BalanceAfter,
snapshot.AvailableQuotaAfter,
snapshot.FrozenQuotaAfter,
snapshot.HistoryQuotaAfter,
); err != nil {
return fmt.Errorf("insert affiliate transfer ledger: %w", err) return fmt.Errorf("insert affiliate transfer ledger: %w", err)
} }
@ -332,6 +383,349 @@ LIMIT $2`, inviterID, limit)
return invitees, nil return invitees, nil
} }
func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
"ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code",
})
total, err := queryAffiliateRecordCount(ctx, client, `
SELECT COUNT(*)
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
`+where, args...)
if err != nil {
return nil, 0, err
}
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"inviter": "inviter.email",
"invitee": "invitee.email",
"aff_code": "inviter_aff.aff_code",
"total_rebate": "total_rebate",
"created_at": "ua.created_at",
}, "ua.created_at")
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT ua.inviter_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ua.user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
COALESCE(inviter_aff.aff_code, ''),
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate,
ua.created_at
FROM user_affiliates ua
JOIN users invitee ON invitee.id = ua.user_id
JOIN users inviter ON inviter.id = ua.inviter_id
JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = ua.inviter_id
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
`+where+`
GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateInviteRecord, 0)
for rows.Next() {
var item service.AffiliateInviteRecord
if err := rows.Scan(
&item.InviterID,
&item.InviterEmail,
&item.InviterUsername,
&item.InviteeID,
&item.InviteeEmail,
&item.InviteeUsername,
&item.AffCode,
&item.TotalRebate,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
"inviter.email", "inviter.username", "invitee.email", "invitee.username",
"po.id::text", "po.out_trade_no", "po.payment_type", "po.status",
})
baseJoin := `
FROM user_affiliate_ledger ual
JOIN payment_orders po ON po.id = ual.source_order_id
JOIN users invitee ON invitee.id = ual.source_user_id
JOIN users inviter ON inviter.id = ual.user_id
WHERE ual.action = 'accrue'
AND ual.source_order_id IS NOT NULL`
if where != "" {
where = strings.Replace(where, "WHERE ", " AND ", 1)
}
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
if err != nil {
return nil, 0, err
}
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"order": "po.id",
"inviter": "inviter.email",
"invitee": "invitee.email",
"order_amount": "po.amount",
"pay_amount": "po.pay_amount",
"rebate_amount": "ual.amount",
"payment_type": "po.payment_type",
"order_status": "po.status",
"created_at": "ual.created_at",
}, "ual.created_at")
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT po.id,
po.out_trade_no,
ual.user_id,
COALESCE(inviter.email, ''),
COALESCE(inviter.username, ''),
ual.source_user_id,
COALESCE(invitee.email, ''),
COALESCE(invitee.username, ''),
po.amount::double precision,
po.pay_amount::double precision,
ual.amount::double precision,
po.payment_type,
po.status,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateRebateRecord, 0)
for rows.Next() {
var item service.AffiliateRebateRecord
if err := rows.Scan(
&item.OrderID,
&item.OutTradeNo,
&item.InviterID,
&item.InviterEmail,
&item.InviterUsername,
&item.InviteeID,
&item.InviteeEmail,
&item.InviteeUsername,
&item.OrderAmount,
&item.PayAmount,
&item.RebateAmount,
&item.PaymentType,
&item.OrderStatus,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) {
client := clientFromContext(ctx, r.client)
where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{
"u.email", "u.username", "u.id::text",
})
baseJoin := `
FROM user_affiliate_ledger ual
JOIN users u ON u.id = ual.user_id
WHERE ual.action = 'transfer'`
if where != "" {
where = strings.Replace(where, "WHERE ", " AND ", 1)
}
total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...)
if err != nil {
return nil, 0, err
}
orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{
"user": "u.email",
"amount": "ual.amount",
"balance_after": "ual.balance_after",
"available_quota_after": "ual.aff_quota_after",
"frozen_quota_after": "ual.aff_frozen_quota_after",
"history_quota_after": "ual.aff_history_quota_after",
"created_at": "ual.created_at",
}, "ual.created_at")
args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize)
rows, err := client.QueryContext(ctx, `
SELECT ual.id,
ual.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ual.amount::double precision,
ual.balance_after::double precision,
ual.aff_quota_after::double precision,
ual.aff_frozen_quota_after::double precision,
ual.aff_history_quota_after::double precision,
ual.created_at
`+baseJoin+where+`
`+orderBy+`
LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
items := make([]service.AffiliateTransferRecord, 0)
for rows.Next() {
var item service.AffiliateTransferRecord
var balanceAfter sql.NullFloat64
var availableQuotaAfter sql.NullFloat64
var frozenQuotaAfter sql.NullFloat64
var historyQuotaAfter sql.NullFloat64
if err := rows.Scan(
&item.LedgerID,
&item.UserID,
&item.UserEmail,
&item.Username,
&item.Amount,
&balanceAfter,
&availableQuotaAfter,
&frozenQuotaAfter,
&historyQuotaAfter,
&item.CreatedAt,
); err != nil {
return nil, 0, err
}
item.BalanceAfter = nullableFloat64Ptr(balanceAfter)
item.AvailableQuotaAfter = nullableFloat64Ptr(availableQuotaAfter)
item.FrozenQuotaAfter = nullableFloat64Ptr(frozenQuotaAfter)
item.HistoryQuotaAfter = nullableFloat64Ptr(historyQuotaAfter)
item.SnapshotAvailable = balanceAfter.Valid &&
availableQuotaAfter.Valid &&
frozenQuotaAfter.Valid &&
historyQuotaAfter.Valid
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return items, total, nil
}
func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) {
if userID <= 0 {
return nil, service.ErrUserNotFound
}
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUserNotFound
}
var overview service.AffiliateUserOverview
var customRate float64
var hasCustomRate bool
if err := rows.Scan(
&overview.UserID,
&overview.Email,
&overview.Username,
&overview.AffCode,
&customRate,
&hasCustomRate,
&overview.InvitedCount,
&overview.RebatedInviteeCount,
&overview.AvailableQuota,
&overview.HistoryQuota,
); err != nil {
return nil, err
}
if hasCustomRate {
overview.RebateRatePercent = customRate
overview.RebateRateCustom = true
}
return &overview, rows.Err()
}
func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) {
clauses := make([]string, 0, 3)
args := make([]any, 0, 3)
if filter.StartAt != nil {
args = append(args, *filter.StartAt)
clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args)))
}
if filter.EndAt != nil {
args = append(args, *filter.EndAt)
clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args)))
}
search := strings.TrimSpace(filter.Search)
if search != "" && len(searchColumns) > 0 {
args = append(args, "%"+strings.ToLower(search)+"%")
parts := make([]string, 0, len(searchColumns))
for _, col := range searchColumns {
parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args)))
}
clauses = append(clauses, "("+strings.Join(parts, " OR ")+")")
}
if len(clauses) == 0 {
return "", args
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string {
column := sortColumns[filter.SortBy]
if column == "" {
column = fallbackColumn
}
direction := "DESC"
if !filter.SortDesc {
direction = "ASC"
}
return "ORDER BY " + column + " " + direction + " NULLS LAST"
}
func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
rows, err := client.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return 0, rows.Err()
}
var total int64
if err := rows.Scan(&total); err != nil {
return 0, err
}
return total, rows.Err()
}
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil { if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client()) return fn(ctx, tx.Client())
@ -516,6 +910,54 @@ func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID i
return balance, nil return balance, nil
} }
type affiliateTransferSnapshot struct {
BalanceAfter float64
AvailableQuotaAfter float64
FrozenQuotaAfter float64
HistoryQuotaAfter float64
}
func queryAffiliateTransferSnapshot(ctx context.Context, client affiliateQueryExecer, userID int64) (*affiliateTransferSnapshot, error) {
rows, err := client.QueryContext(ctx, `
SELECT u.balance::double precision,
ua.aff_quota::double precision,
ua.aff_frozen_quota::double precision,
ua.aff_history_quota::double precision
FROM users u
JOIN user_affiliates ua ON ua.user_id = u.id
WHERE u.id = $1
LIMIT 1`, userID)
if err != nil {
return nil, fmt.Errorf("query affiliate transfer snapshot: %w", err)
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUserNotFound
}
var snapshot affiliateTransferSnapshot
if err := rows.Scan(
&snapshot.BalanceAfter,
&snapshot.AvailableQuotaAfter,
&snapshot.FrozenQuotaAfter,
&snapshot.HistoryQuotaAfter,
); err != nil {
return nil, err
}
return &snapshot, rows.Err()
}
func nullableFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
return &v.Float64
}
func generateAffiliateCode() (string, error) { func generateAffiliateCode() (string, error) {
buf := make([]byte, affiliateCodeLength) buf := make([]byte, affiliateCodeLength)
if _, err := rand.Read(buf); err != nil { if _, err := rand.Read(buf); err != nil {
@ -674,6 +1116,13 @@ func nullableArg(v *float64) any {
return *v return *v
} }
func nullableInt64Arg(v *int64) any {
if v == nil {
return nil
}
return *v
}
// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。 // ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
// //
// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索" // 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索"

View File

@ -78,6 +78,26 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
ledgerCount := querySingleInt(t, txCtx, client, ledgerCount := querySingleInt(t, txCtx, client,
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID) "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
require.Equal(t, 1, ledgerCount) require.Equal(t, 1, ledgerCount)
rows, err := client.QueryContext(txCtx, `
SELECT amount::double precision,
balance_after::double precision,
aff_quota_after::double precision,
aff_frozen_quota_after::double precision,
aff_history_quota_after::double precision
FROM user_affiliate_ledger
WHERE user_id = $1 AND action = 'transfer'
LIMIT 1`, u.ID)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected transfer ledger")
var amount, balanceAfter, quotaAfter, frozenAfter, historyAfter float64
require.NoError(t, rows.Scan(&amount, &balanceAfter, &quotaAfter, &frozenAfter, &historyAfter))
require.InDelta(t, 12.34, amount, 1e-9)
require.InDelta(t, 17.84, balanceAfter, 1e-9)
require.InDelta(t, 0.0, quotaAfter, 1e-9)
require.InDelta(t, 0.0, frozenAfter, 1e-9)
require.InDelta(t, 12.34, historyAfter, 1e-9)
} }
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the // TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
@ -125,7 +145,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter") require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0) applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0, nil)
require.NoError(t, err) require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true") require.True(t, applied, "AccrueQuota must report applied=true")

View File

@ -0,0 +1,28 @@
package repository
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) {
query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ")
require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)")
require.Contains(t, query, "frozen_until <= NOW()")
}
func TestAffiliateRecordQueriesUseLedgerAuditFields(t *testing.T) {
source, err := os.ReadFile("affiliate_repo.go")
require.NoError(t, err)
content := string(source)
require.Contains(t, content, "JOIN payment_orders po ON po.id = ual.source_order_id")
require.Contains(t, content, "ual.amount::double precision")
require.Contains(t, content, "ual.balance_after::double precision")
require.NotContains(t, content, "parseAffiliateRebateAmount")
require.NotContains(t, content, `"current_balance": "u.balance"`)
}

View File

@ -602,11 +602,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
affiliates := admin.Group("/affiliates") affiliates := admin.Group("/affiliates")
{ {
affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords)
affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords)
affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords)
users := affiliates.Group("/users") users := affiliates.Group("/users")
{ {
users.GET("", h.Admin.Affiliate.ListUsers) users.GET("", h.Admin.Affiliate.ListUsers)
users.GET("/lookup", h.Admin.Affiliate.LookupUsers) users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate) users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview)
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings) users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings) users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
} }

View File

@ -0,0 +1,86 @@
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
func TestMergeBalanceHistoryCodesIncludesAffiliateTransfersByDefault(t *testing.T) {
t.Parallel()
now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
older := now.Add(-2 * time.Hour)
newer := now.Add(time.Hour)
usedBy := int64(10)
redeemCodes := []RedeemCode{
{
ID: 1,
Type: RedeemTypeBalance,
Value: 8,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &now,
CreatedAt: now,
},
{
ID: 2,
Type: RedeemTypeConcurrency,
Value: 1,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &older,
CreatedAt: older,
},
}
affiliateCodes := []RedeemCode{
{
ID: -20,
Type: RedeemTypeAffiliateBalance,
Value: 3.5,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &newer,
CreatedAt: newer,
},
}
got := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, pagination.PaginationParams{
Page: 1,
PageSize: 2,
})
require.Len(t, got, 2)
require.Equal(t, RedeemTypeAffiliateBalance, got[0].Type)
require.Equal(t, RedeemTypeBalance, got[1].Type)
}
func TestMergeBalanceHistoryCodesPaginatesAfterCombiningSources(t *testing.T) {
t.Parallel()
base := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC)
usedBy := int64(10)
at := func(hours int) *time.Time {
v := base.Add(time.Duration(hours) * time.Hour)
return &v
}
got := mergeBalanceHistoryCodes(
[]RedeemCode{
{ID: 1, Type: RedeemTypeBalance, UsedBy: &usedBy, UsedAt: at(4), CreatedAt: *at(4)},
{ID: 2, Type: RedeemTypeConcurrency, UsedBy: &usedBy, UsedAt: at(2), CreatedAt: *at(2)},
},
[]RedeemCode{
{ID: -3, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(3), CreatedAt: *at(3)},
{ID: -4, Type: RedeemTypeAffiliateBalance, UsedBy: &usedBy, UsedAt: at(1), CreatedAt: *at(1)},
},
pagination.PaginationParams{Page: 2, PageSize: 2},
)
require.Len(t, got, 2)
require.Equal(t, RedeemTypeConcurrency, got[0].Type)
require.Equal(t, int64(-4), got[1].ID)
}

View File

@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -973,16 +974,213 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
if codeType == RedeemTypeAffiliateBalance {
codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params)
if err != nil {
return nil, 0, 0, err
}
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, total, totalRecharged, nil
}
if codeType == "" {
return s.getAllUserBalanceHistory(ctx, userID, params)
}
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
total := result.Total
// Aggregate total recharged amount (only once, regardless of type filter) // Aggregate total recharged amount (only once, regardless of type filter)
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
return codes, result.Total, totalRecharged, nil return codes, total, totalRecharged, nil
}
func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, error) {
needed := params.Offset() + params.Limit()
if needed < params.Limit() {
needed = params.Limit()
}
redeemCodes, redeemTotal, err := s.listRedeemBalanceHistoryForMerge(ctx, userID, needed)
if err != nil {
return nil, 0, 0, err
}
affiliateCodes, affiliateTotal, err := s.listAffiliateBalanceHistoryForMerge(ctx, userID, needed)
if err != nil {
return nil, 0, 0, err
}
codes := mergeBalanceHistoryCodes(redeemCodes, affiliateCodes, params)
totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID)
if err != nil {
return nil, 0, 0, err
}
return codes, redeemTotal + affiliateTotal, totalRecharged, nil
}
func (s *adminServiceImpl) listRedeemBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
if needed <= 0 {
return nil, 0, nil
}
var (
out []RedeemCode
total int64
)
for page := 1; len(out) < needed; page++ {
params := pagination.PaginationParams{Page: page, PageSize: 1000}
codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, "")
if err != nil {
return nil, 0, err
}
if result != nil {
total = result.Total
}
out = append(out, codes...)
if len(codes) < params.Limit() || int64(len(out)) >= total {
break
}
}
if len(out) > needed {
out = out[:needed]
}
return out, total, nil
}
func (s *adminServiceImpl) listAffiliateBalanceHistoryForMerge(ctx context.Context, userID int64, needed int) ([]RedeemCode, int64, error) {
if needed <= 0 {
return nil, 0, nil
}
var (
out []RedeemCode
total int64
)
for page := 1; len(out) < needed; page++ {
params := pagination.PaginationParams{Page: page, PageSize: 1000}
codes, currentTotal, err := s.listAffiliateBalanceHistory(ctx, userID, params)
if err != nil {
return nil, 0, err
}
total = currentTotal
out = append(out, codes...)
if len(codes) < params.Limit() || int64(len(out)) >= total {
break
}
}
if len(out) > needed {
out = out[:needed]
}
return out, total, nil
}
func (s *adminServiceImpl) listAffiliateBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, error) {
if s == nil || s.entClient == nil || userID <= 0 {
return nil, 0, nil
}
rows, err := s.entClient.QueryContext(ctx, `
SELECT id,
amount::double precision,
created_at
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'
ORDER BY created_at DESC, id DESC
OFFSET $2
LIMIT $3`, userID, params.Offset(), params.Limit())
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
codes := make([]RedeemCode, 0, params.Limit())
for rows.Next() {
var id int64
var amount float64
var createdAt time.Time
if err := rows.Scan(&id, &amount, &createdAt); err != nil {
return nil, 0, err
}
usedBy := userID
usedAt := createdAt
codes = append(codes, RedeemCode{
ID: -id,
Code: fmt.Sprintf("AFF-%d", id),
Type: RedeemTypeAffiliateBalance,
Value: amount,
Status: StatusUsed,
UsedBy: &usedBy,
UsedAt: &usedAt,
CreatedAt: createdAt,
})
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
total, err := countAffiliateBalanceHistory(ctx, s.entClient, userID)
if err != nil {
return nil, 0, err
}
return codes, total, nil
}
func countAffiliateBalanceHistory(ctx context.Context, client *dbent.Client, userID int64) (int64, error) {
rows, err := client.QueryContext(ctx, `
SELECT COUNT(*)
FROM user_affiliate_ledger
WHERE user_id = $1
AND action = 'transfer'`, userID)
if err != nil {
return 0, err
}
defer func() { _ = rows.Close() }()
var total sql.NullInt64
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
if err := rows.Err(); err != nil {
return 0, err
}
if !total.Valid {
return 0, nil
}
return total.Int64, nil
}
func mergeBalanceHistoryCodes(redeemCodes, affiliateCodes []RedeemCode, params pagination.PaginationParams) []RedeemCode {
combined := append(append([]RedeemCode{}, redeemCodes...), affiliateCodes...)
sort.SliceStable(combined, func(i, j int) bool {
return redeemCodeHistoryTime(combined[i]).After(redeemCodeHistoryTime(combined[j]))
})
offset := params.Offset()
if offset >= len(combined) {
return []RedeemCode{}
}
end := offset + params.Limit()
if end > len(combined) {
end = len(combined)
}
return combined[offset:end]
}
func redeemCodeHistoryTime(code RedeemCode) time.Time {
if code.UsedAt != nil {
return *code.UsedAt
}
return code.CreatedAt
} }
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {

View File

@ -98,7 +98,7 @@ type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
@ -110,6 +110,10 @@ type AffiliateRepository interface {
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error)
ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error)
ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error)
GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error)
} }
// AffiliateAdminFilter 列表筛选条件 // AffiliateAdminFilter 列表筛选条件
@ -130,6 +134,76 @@ type AffiliateAdminEntry struct {
AffCount int `json:"aff_count"` AffCount int `json:"aff_count"`
} }
type AffiliateRecordFilter struct {
Search string
Page int
PageSize int
StartAt *time.Time
EndAt *time.Time
SortBy string
SortDesc bool
}
type AffiliateInviteRecord struct {
InviterID int64 `json:"inviter_id"`
InviterEmail string `json:"inviter_email"`
InviterUsername string `json:"inviter_username"`
InviteeID int64 `json:"invitee_id"`
InviteeEmail string `json:"invitee_email"`
InviteeUsername string `json:"invitee_username"`
AffCode string `json:"aff_code"`
TotalRebate float64 `json:"total_rebate"`
CreatedAt time.Time `json:"created_at"`
}
type AffiliateRebateRecord struct {
OrderID int64 `json:"order_id"`
OutTradeNo string `json:"out_trade_no"`
InviterID int64 `json:"inviter_id"`
InviterEmail string `json:"inviter_email"`
InviterUsername string `json:"inviter_username"`
InviteeID int64 `json:"invitee_id"`
InviteeEmail string `json:"invitee_email"`
InviteeUsername string `json:"invitee_username"`
OrderAmount float64 `json:"order_amount"`
PayAmount float64 `json:"pay_amount"`
RebateAmount float64 `json:"rebate_amount"`
PaymentType string `json:"payment_type"`
OrderStatus string `json:"order_status"`
CreatedAt time.Time `json:"created_at"`
}
type AffiliateTransferRecord struct {
LedgerID int64 `json:"ledger_id"`
UserID int64 `json:"user_id"`
UserEmail string `json:"user_email"`
Username string `json:"username"`
Amount float64 `json:"amount"`
BalanceAfter *float64 `json:"balance_after,omitempty"`
AvailableQuotaAfter *float64 `json:"available_quota_after,omitempty"`
FrozenQuotaAfter *float64 `json:"frozen_quota_after,omitempty"`
HistoryQuotaAfter *float64 `json:"history_quota_after,omitempty"`
SnapshotAvailable bool `json:"snapshot_available"`
CurrentBalance float64 `json:"-"`
RemainingQuota float64 `json:"-"`
FrozenQuota float64 `json:"-"`
HistoryQuota float64 `json:"-"`
CreatedAt time.Time `json:"created_at"`
}
type AffiliateUserOverview struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
AffCode string `json:"aff_code"`
RebateRatePercent float64 `json:"rebate_rate_percent"`
RebateRateCustom bool `json:"-"`
InvitedCount int `json:"invited_count"`
RebatedInviteeCount int `json:"rebated_invitee_count"`
AvailableQuota float64 `json:"available_quota"`
HistoryQuota float64 `json:"history_quota"`
}
type AffiliateService struct { type AffiliateService struct {
repo AffiliateRepository repo AffiliateRepository
settingService *SettingService settingService *SettingService
@ -238,6 +312,10 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
} }
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) { func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
return s.AccrueInviteRebateForOrder(ctx, inviteeUserID, baseRechargeAmount, nil)
}
func (s *AffiliateService) AccrueInviteRebateForOrder(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64, sourceOrderID *int64) (float64, error) {
if s == nil || s.repo == nil { if s == nil || s.repo == nil {
return 0, nil return 0, nil
} }
@ -298,7 +376,7 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx) freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
} }
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours) applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours, sourceOrderID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -488,3 +566,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi
} }
return s.repo.ListUsersWithCustomSettings(ctx, filter) return s.repo.ListUsersWithCustomSettings(ctx, filter)
} }
func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter))
}
func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter))
}
func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter))
}
func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
}
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
overview, err := s.repo.GetAffiliateUserOverview(ctx, userID)
if err != nil {
return nil, err
}
if overview != nil {
if !overview.RebateRateCustom {
overview.RebateRatePercent = s.globalRebateRatePercent(ctx)
}
overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent)
}
return overview, nil
}
func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter {
if filter.Page <= 0 {
filter.Page = 1
}
if filter.PageSize <= 0 {
filter.PageSize = 20
}
if filter.PageSize > 100 {
filter.PageSize = 100
}
filter.Search = strings.TrimSpace(filter.Search)
filter.SortBy = strings.TrimSpace(filter.SortBy)
return filter
}

View File

@ -51,10 +51,11 @@ const (
// Redeem type constants // Redeem type constants
const ( const (
RedeemTypeBalance = domain.RedeemTypeBalance RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation RedeemTypeInvitation = domain.RedeemTypeInvitation
RedeemTypeAffiliateBalance = "affiliate_balance"
) )
// PromoCode status constants // PromoCode status constants

View File

@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
} }
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if ctx == nil {
return context.Background(), func() {}
}
if !stream { if !stream {
return ctx, func() {} return ctx, func() {}
} }
return context.WithoutCancel(ctx), func() {}
}
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil { if ctx == nil {
return context.Background(), func() {} return context.Background(), func() {}
} }

View File

@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type upstreamContextTestKey string
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 3, result.usage.InputTokens)
require.Equal(t, 7, result.usage.OutputTokens) require.Equal(t, 7, result.usage.OutputTokens)
} }
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
upstreamCtx, release := detachUpstreamContext(parent)
defer release()
cancel()
require.NoError(t, upstreamCtx.Err())
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
}

View File

@ -3,13 +3,16 @@ package service
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -18,6 +21,51 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
type openAICompatFailingWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
type openAICompatBlockingReadCloser struct {
data []byte
offset int
closed chan struct{}
closeOnce sync.Once
}
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
return &openAICompatBlockingReadCloser{
data: data,
closed: make(chan struct{}),
}
}
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
if r.offset < len(r.data) {
n := copy(p, r.data[r.offset:])
r.offset += n
return n, nil
}
<-r.closed
return 0, io.EOF
}
func (r *openAICompatBlockingReadCloser) Close() error {
r.closeOnce.Do(func() {
close(r.closed)
})
return nil
}
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
t.Parallel() t.Parallel()
@ -228,3 +276,242 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
} }
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
"",
`data: {"type":"response.output_text.delta","delta":"ok"}`,
"",
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 9, result.Usage.InputTokens)
require.Equal(t, 4, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
}
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 15, got.result.Usage.InputTokens)
require.Equal(t, 6, got.result.Usage.OutputTokens)
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 15, got.result.Usage.InputTokens)
require.Equal(t, 6, got.result.Usage.OutputTokens)
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := "data: [DONE]\n\n"
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
require.Zero(t, result.Usage.InputTokens)
require.Zero(t, result.Usage.OutputTokens)
}
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx)
c.Request.Header.Set("Content-Type", "application/json")
cancel()
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}

View File

@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou
return nil return nil
} }
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) {
counter := &openAI403CounterResetStub{} counter := &openAI403CounterResetStub{}
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
rateLimitSvc.SetOpenAI403CounterCache(counter) rateLimitSvc.SetOpenAI403CounterCache(counter)
svc := &OpenAIGatewayService{ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
rateLimitService: rateLimitSvc, billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
svc.rateLimitService = rateLimitSvc
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{}, Result: &OpenAIForwardResult{
RequestID: "resp_zero_usage_reset_403",
Model: "gpt-5.1",
},
APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}},
User: &User{ID: 2001},
Account: &Account{ID: 777, Platform: PlatformOpenAI}, Account: &Account{ID: 777, Platform: PlatformOpenAI},
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []int64{777}, counter.resetCalls) require.Equal(t, []int64{777}, counter.resetCalls)
require.Equal(t, 1, usageRepo.calls)
} }

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
} }
// 6. Build upstream request // 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
maxLineSize := defaultMaxLineSize if err != nil {
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { return nil, err
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
} }
if finalResponse == nil { if finalResponse == nil {
@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var usage OpenAIUsage var usage OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
firstChunk := true firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
return false return false
} }
// Extract usage from completion events // 仅按兼容转换器支持的终止事件提取 usage避免无意扩大事件语义。
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
event.Response != nil && event.Response.Usage != nil { if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{ usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
} }
chunks := apicompat.ResponsesEventToChatChunks(&event, state) chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks { if !clientDisconnected {
sse, err := apicompat.ChatChunkToSSE(chunk) for _, chunk := range chunks {
if err != nil { sse, err := apicompat.ChatChunkToSSE(chunk)
logger.L().Warn("openai chat_completions stream: failed to marshal chunk", if err != nil {
zap.Error(err), logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.String("request_id", requestID), zap.Error(err),
) zap.String("request_id", requestID),
continue )
} continue
if _, err := fmt.Fprint(c.Writer, sse); err != nil { }
logger.L().Info("openai chat_completions stream: client disconnected", if _, err := fmt.Fprint(c.Writer, sse); err != nil {
zap.String("request_id", requestID), clientDisconnected = true
) logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
return true zap.String("request_id", requestID),
)
break
}
} }
} }
if len(chunks) > 0 { if len(chunks) > 0 && !clientDisconnected {
c.Writer.Flush() c.Writer.Flush()
} }
return false return isTerminalEvent
} }
finalizeStream := func() (*OpenAIForwardResult, error) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
for _, chunk := range finalChunks { for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk) sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil { if err != nil {
continue continue
} }
fmt.Fprint(c.Writer, sse) //nolint:errcheck if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
} }
} }
// Send [DONE] sentinel // Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck if !clientDisconnected {
c.Writer.Flush() if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
zap.String("request_id", requestID),
)
}
}
if !clientDisconnected {
c.Writer.Flush()
}
return resultWithUsage(), nil return resultWithUsage(), nil
} }
@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// Determine keepalive interval // Determine keepalive interval
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
// No keepalive: fast synchronous path // No keepalive: fast synchronous path
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// With keepalive: goroutine + channel + select // With keepalive: goroutine + channel + select
@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
select { select {
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai chat_completions stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
logger.L().Info("openai chat_completions stream: client disconnected during keepalive", logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }

View File

@ -1,13 +1,36 @@
package service package service
import ( import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
type openAIChatFailingWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func TestNormalizeResponsesRequestServiceTier(t *testing.T) { func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
t.Parallel() t.Parallel()
@ -73,3 +96,242 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Empty(t, tier) require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists()) require.False(t, gjson.GetBytes(body, "service_tier").Exists())
} }
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
"",
`data: {"type":"response.output_text.delta","delta":"ok"}`,
"",
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 11, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
}
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 17, got.result.Usage.InputTokens)
require.Equal(t, 8, got.result.Usage.OutputTokens)
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
defer func() {
require.NoError(t, upstreamStream.Close())
}()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
Body: upstreamStream,
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
type forwardResult struct {
result *OpenAIForwardResult
err error
}
resultCh := make(chan forwardResult, 1)
go func() {
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
resultCh <- forwardResult{result: result, err: err}
}()
select {
case got := <-resultCh:
require.NoError(t, got.err)
require.NotNil(t, got.result)
require.Equal(t, 17, got.result.Usage.InputTokens)
require.Equal(t, 8, got.result.Usage.OutputTokens)
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
case <-time.After(time.Second):
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
}
}
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := "data: [DONE]\n\n"
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
require.Zero(t, result.Usage.InputTokens)
require.Zero(t, result.Usage.OutputTokens)
}
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
c.Request.Header.Set("Content-Type", "application/json")
cancel()
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 6. Build upstream request // 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
maxLineSize := defaultMaxLineSize if err != nil {
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { return nil, err
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
// Terminal events carry the complete ResponsesResponse with output + usage.
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
} }
if finalResponse == nil { if finalResponse == nil {
@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
}, nil }, nil
} }
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
switch strings.TrimSpace(eventType) {
case "response.completed", "response.done", "response.incomplete", "response.failed":
return true
default:
return false
}
}
func isOpenAICompatDoneSentinelLine(line string) bool {
payload, ok := extractOpenAISSEDataLine(line)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
resp *http.Response,
logPrefix string,
requestID string,
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
acc := apicompat.NewBufferedResponseAccumulator()
var usage OpenAIUsage
if resp == nil || resp.Body == nil {
return nil, usage, acc, errors.New("upstream response body is nil")
}
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var timeoutCh <-chan time.Time
var timeoutTimer *time.Timer
resetTimeout := func() {
if streamInterval <= 0 {
return
}
if timeoutTimer == nil {
timeoutTimer = time.NewTimer(streamInterval)
timeoutCh = timeoutTimer.C
return
}
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
timeoutTimer.Reset(streamInterval)
}
stopTimeout := func() {
if timeoutTimer == nil {
return
}
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
}
resetTimeout()
defer stopTimeout()
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
go func() {
defer close(events)
for scanner.Scan() {
select {
case events <- scanEvent{line: scanner.Text()}:
case <-done:
return
}
}
if err := scanner.Err(); err != nil {
select {
case events <- scanEvent{err: err}:
case <-done:
}
}
}()
defer close(done)
for {
select {
case ev, ok := <-events:
if !ok {
return nil, usage, acc, nil
}
resetTimeout()
if ev.err != nil {
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
logger.L().Warn(logPrefix+": read error",
zap.Error(ev.err),
zap.String("request_id", requestID),
)
}
return nil, usage, acc, ev.err
}
if isOpenAICompatDoneSentinelLine(ev.line) {
return nil, usage, acc, nil
}
payload, ok := extractOpenAISSEDataLine(ev.line)
if !ok || payload == "" {
continue
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn(logPrefix+": failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
acc.ProcessEvent(&event)
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
if event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
return event.Response, usage, acc, nil
}
case <-timeoutCh:
_ = resp.Body.Close()
logger.L().Warn(logPrefix+": data interval timeout",
zap.String("request_id", requestID),
zap.Duration("interval", streamInterval),
)
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
}
}
}
// handleAnthropicStreamingResponse reads Responses SSE events from upstream, // handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client. // converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel // When StreamKeepaliveInterval is configured, it uses a goroutine + channel
@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
var usage OpenAIUsage var usage OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
firstChunk := true firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// resultWithUsage builds the final result snapshot. // resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// processDataLine handles a single "data: ..." SSE line from upstream. // processDataLine handles a single "data: ..." SSE line from upstream.
// Returns (clientDisconnected bool).
processDataLine := func(payload string) bool { processDataLine := func(payload string) bool {
if firstChunk { if firstChunk {
firstChunk = false firstChunk = false
@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
return false return false
} }
// Extract usage from completion events // 仅按兼容转换器支持的终止事件提取 usage避免无意扩大事件语义。
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
event.Response != nil && event.Response.Usage != nil { if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{ usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
} }
// Convert to Anthropic events // Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state) events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events { if !clientDisconnected {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) for _, evt := range events {
if err != nil { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
logger.L().Warn("openai messages stream: failed to marshal event", if err != nil {
zap.Error(err), logger.L().Warn("openai messages stream: failed to marshal event",
zap.String("request_id", requestID), zap.Error(err),
) zap.String("request_id", requestID),
continue )
} continue
if _, err := fmt.Fprint(c.Writer, sse); err != nil { }
logger.L().Info("openai messages stream: client disconnected", if _, err := fmt.Fprint(c.Writer, sse); err != nil {
zap.String("request_id", requestID), clientDisconnected = true
) logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
return true zap.String("request_id", requestID),
)
break
}
} }
} }
if len(events) > 0 { if len(events) > 0 && !clientDisconnected {
c.Writer.Flush() c.Writer.Flush()
} }
return false return isTerminalEvent
} }
// finalizeStream sends any remaining Anthropic events and returns the result. // finalizeStream sends any remaining Anthropic events and returns the result.
finalizeStream := func() (*OpenAIForwardResult, error) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
for _, evt := range finalEvents { for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil { if err != nil {
continue continue
} }
fmt.Fprint(c.Writer, sse) //nolint:errcheck if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai messages stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
}
if !clientDisconnected {
c.Writer.Flush()
} }
c.Writer.Flush()
} }
return resultWithUsage(), nil return resultWithUsage(), nil
} }
@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// ── Determine keepalive interval ── // ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// ── No keepalive: fast synchronous path (no goroutine overhead) ── // ── No keepalive: fast synchronous path (no goroutine overhead) ──
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if processDataLine(payload) {
return resultWithUsage(), nil return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// ── With keepalive: goroutine + channel + select ── // ── With keepalive: goroutine + channel + select ──
@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
// Upstream closed // Upstream closed
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if processDataLine(payload) {
return resultWithUsage(), nil return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai messages stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
logger.L().Info("openai messages stream: client disconnected during keepalive", logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }
@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
}, },
}) })
} }
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
if usage == nil {
return OpenAIUsage{}
}
result := OpenAIUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
}
if usage.InputTokensDetails != nil {
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
}
return result
}

View File

@ -186,6 +186,56 @@ func max(a, b int) int {
return b return b
} }
func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_zero_usage",
Usage: OpenAIUsage{},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}},
User: &User{ID: 2000},
Account: &Account{ID: 3000, Type: AccountTypeAPIKey},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 0, quotaSvc.quotaCalls)
require.Equal(t, 0, quotaSvc.rateLimitCalls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID)
require.Zero(t, usageRepo.lastLog.InputTokens)
require.Zero(t, usageRepo.lastLog.OutputTokens)
require.Zero(t, usageRepo.lastLog.CacheCreationTokens)
require.Zero(t, usageRepo.lastLog.CacheReadTokens)
require.Zero(t, usageRepo.lastLog.ImageOutputTokens)
require.Zero(t, usageRepo.lastLog.ImageCount)
require.Zero(t, usageRepo.lastLog.InputCost)
require.Zero(t, usageRepo.lastLog.OutputCost)
require.Zero(t, usageRepo.lastLog.TotalCost)
require.Zero(t, usageRepo.lastLog.ActualCost)
require.NotNil(t, billingRepo.lastCmd)
require.Zero(t, billingRepo.lastCmd.BalanceCost)
require.Zero(t, billingRepo.lastCmd.SubscriptionCost)
require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost)
require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost)
require.Zero(t, billingRepo.lastCmd.AccountQuotaCost)
}
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
groupID := int64(11) groupID := int64(11)
groupRate := 1.4 groupRate := 1.4

View File

@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
httpInvalidEncryptedContentRetryTried := false httpInvalidEncryptedContentRetryTried := false
for { for {
// Build upstream request // Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err return nil, err
} }
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
@ -5041,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
} }
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
return nil
}
apiKey := input.APIKey apiKey := input.APIKey
user := input.User user := input.User
account := input.Account account := input.Account

View File

@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
var usage OpenAIUsage var usage OpenAIUsage
imageCount := parsed.N imageCount := parsed.N
var firstTokenMs *int var firstTokenMs *int
if parsed.Stream { if parsed.Stream && isEventStreamResponse(resp.Header) {
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
if err != nil { if err != nil {
return nil, err return nil, err
@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
usage := OpenAIUsage{} usage := OpenAIUsage{}
imageCount := 0 imageCount := 0
var firstTokenMs *int var firstTokenMs *int
var fallbackBody bytes.Buffer
fallbackBytes := int64(0)
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
seenSSEData := false
fallbackTooLarge := false
for { for {
line, err := reader.ReadBytes('\n') line, err := reader.ReadBytes('\n')
@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
} }
flusher.Flush() flusher.Flush()
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
dataBytes := []byte(data) if data != "" && data != "[DONE]" {
mergeOpenAIUsage(&usage, dataBytes) seenSSEData = true
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { fallbackBody.Reset()
imageCount = count fallbackBytes = 0
dataBytes := []byte(data)
mergeOpenAIUsage(&usage, dataBytes)
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
imageCount = count
}
}
} else if !seenSSEData && !fallbackTooLarge {
fallbackBytes += int64(len(line))
if fallbackBytes <= fallbackLimit {
_, _ = fallbackBody.Write(line)
} else {
fallbackTooLarge = true
fallbackBody.Reset()
} }
} }
} }
@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
return OpenAIUsage{}, 0, firstTokenMs, err return OpenAIUsage{}, 0, firstTokenMs, err
} }
} }
if !seenSSEData && fallbackBody.Len() > 0 {
body := bytes.TrimSpace(fallbackBody.Bytes())
if len(body) > 0 {
mergeOpenAIUsage(&usage, body)
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
imageCount = count
}
}
}
return usage, imageCount, firstTokenMs, nil return usage, imageCount, firstTokenMs, nil
} }
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
return count
}
if len(body) == 0 || !gjson.ValidBytes(body) {
return 0
}
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
return count
}
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
return count
}
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
return 0
}
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
return 1
}
return 0
}
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
if dst == nil { if dst == nil {
return return

View File

@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
} }
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req_img_stream_json"},
},
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
account := &Account{
ID: 7,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-api-key",
"base_url": "https://image-upstream.example/v1",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 12, result.Usage.InputTokens)
require.Equal(t, 21, result.Usage.OutputTokens)
require.Equal(t, 9, result.Usage.ImageOutputTokens)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
svc := &OpenAIGatewayService{
cfg: &config.Config{},
httpUpstream: &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
},
Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
},
},
}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
account := &Account{
ID: 8,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-api-key",
"base_url": "https://image-upstream.example/v1",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.Stream)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 10, result.Usage.InputTokens)
require.Equal(t, 18, result.Usage.OutputTokens)
require.Equal(t, 8, result.Usage.ImageOutputTokens)
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
}
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
}
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
} }
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
cancel()
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
result, err := svc.Forward(reqCtx, c, account, originalBody)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t) logSink, restore := captureStructuredLog(t)
@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
require.Contains(t, string(upstream.lastBody), `"stream":true`) require.Contains(t, string(upstream.lastBody), `"stream":true`)
} }
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
reqCtx, cancel := context.WithCancel(context.Background())
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
cancel()
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
"",
"data: [DONE]",
"",
}, "\n"))),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
result, err := svc.Forward(reqCtx, c, account, originalBody)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.NoError(t, upstream.lastReq.Context().Err())
}
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -394,7 +394,8 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
return nil return nil
} }
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) sourceOrderID := o.ID
rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID)
if err != nil { if err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(), "error": err.Error(),

View File

@ -0,0 +1,85 @@
-- 邀请返利流水补充订单关联和转余额快照。
-- 这些字段只用于审计展示;历史旧流水无法可靠反推的字段保持 NULL避免把当前状态误展示为历史状态。
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS source_order_id BIGINT NULL REFERENCES payment_orders(id) ON DELETE SET NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8) NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8) NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS aff_frozen_quota_after DECIMAL(20,8) NULL;
ALTER TABLE user_affiliate_ledger
ADD COLUMN IF NOT EXISTS aff_history_quota_after DECIMAL(20,8) NULL;
COMMENT ON COLUMN user_affiliate_ledger.source_order_id IS '产生该返利流水的充值订单;转余额或无法可靠回填的历史数据为 NULL';
COMMENT ON COLUMN user_affiliate_ledger.balance_after IS '邀请返利转余额后的用户余额快照;无法取得时为 NULL';
COMMENT ON COLUMN user_affiliate_ledger.aff_quota_after IS '邀请返利转余额后的可用返利额度快照;无法取得时为 NULL';
COMMENT ON COLUMN user_affiliate_ledger.aff_frozen_quota_after IS '邀请返利转余额后的冻结返利额度快照;无法取得时为 NULL';
COMMENT ON COLUMN user_affiliate_ledger.aff_history_quota_after IS '邀请返利转余额后的历史返利总额快照;无法取得时为 NULL';
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_source_order_id
ON user_affiliate_ledger(source_order_id)
WHERE source_order_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_rebate_lookup
ON user_affiliate_ledger(action, source_order_id, user_id, source_user_id, created_at)
WHERE action = 'accrue';
-- 尽力回填 PR #2169 合并后、该迁移前已经产生的返利流水。
-- 只有在同一订单只能匹配到一条返利流水时才回填,避免把多笔同额流水错误绑定到订单。
WITH rebate_audits AS (
SELECT po.id AS order_id,
po.user_id AS invitee_user_id,
invitee_aff.inviter_id,
rebate_detail.rebate_amount,
pal.created_at AS audit_created_at
FROM payment_audit_logs pal
CROSS JOIN LATERAL (
SELECT substring(
pal.detail
FROM '"rebateAmount"[[:space:]]*:[[:space:]]*(-?[0-9]+(\.[0-9]+)?)'
)::numeric AS rebate_amount
) rebate_detail
JOIN payment_orders po ON po.id::text = pal.order_id
JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id
WHERE pal.action = 'AFFILIATE_REBATE_APPLIED'
AND rebate_detail.rebate_amount IS NOT NULL
),
ranked_matches AS (
SELECT ual.id AS ledger_id,
ra.order_id,
COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count,
COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count,
ROW_NUMBER() OVER (
PARTITION BY ual.id
ORDER BY ABS(EXTRACT(EPOCH FROM (ual.created_at - ra.audit_created_at))), ra.order_id
) AS ledger_rank
FROM rebate_audits ra
JOIN user_affiliate_ledger ual
ON ual.action = 'accrue'
AND ual.source_order_id IS NULL
AND ual.user_id = ra.inviter_id
AND ual.source_user_id = ra.invitee_user_id
AND ABS(ual.amount - ra.rebate_amount) < 0.00000001
AND ual.created_at BETWEEN ra.audit_created_at - INTERVAL '10 minutes'
AND ra.audit_created_at + INTERVAL '10 minutes'
)
UPDATE user_affiliate_ledger ual
SET source_order_id = ranked_matches.order_id,
updated_at = NOW()
FROM ranked_matches
WHERE ual.id = ranked_matches.ledger_id
AND ranked_matches.order_match_count = 1
AND ranked_matches.ledger_match_count = 1
AND ranked_matches.ledger_rank = 1
AND NOT EXISTS (
SELECT 1
FROM user_affiliate_ledger existing
WHERE existing.source_order_id = ranked_matches.order_id
AND existing.action = 'accrue'
);

View File

@ -127,3 +127,18 @@ func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
require.Contains(t, sql, "oidc_connect_enabled") require.Contains(t, sql, "oidc_connect_enabled")
require.Contains(t, sql, "'false'") require.Contains(t, sql, "'false'")
} }
func TestMigration134AddsAffiliateLedgerAuditFieldsWithoutJSONCast(t *testing.T) {
content, err := FS.ReadFile("134_affiliate_ledger_audit_snapshots.sql")
require.NoError(t, err)
sql := string(content)
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS source_order_id BIGINT")
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS balance_after DECIMAL(20,8)")
require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS aff_quota_after DECIMAL(20,8)")
require.Contains(t, sql, "substring(")
require.Contains(t, sql, `"rebateAmount"`)
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ra.order_id) AS order_match_count")
require.Contains(t, sql, "COUNT(*) OVER (PARTITION BY ual.id) AS ledger_match_count")
require.NotContains(t, sql, "detail::jsonb")
}

View File

@ -23,6 +23,72 @@ export interface ListAffiliateUsersParams {
search?: string search?: string
} }
export interface ListAffiliateRecordsParams {
page?: number
page_size?: number
search?: string
start_at?: string
end_at?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
timezone?: string
}
export interface AffiliateInviteRecord {
inviter_id: number
inviter_email: string
inviter_username: string
invitee_id: number
invitee_email: string
invitee_username: string
aff_code: string
total_rebate: number
created_at: string
}
export interface AffiliateRebateRecord {
order_id: number
out_trade_no: string
inviter_id: number
inviter_email: string
inviter_username: string
invitee_id: number
invitee_email: string
invitee_username: string
order_amount: number
pay_amount: number
rebate_amount: number
payment_type: string
order_status: string
created_at: string
}
export interface AffiliateTransferRecord {
ledger_id: number
user_id: number
user_email: string
username: string
amount: number
balance_after?: number | null
available_quota_after?: number | null
frozen_quota_after?: number | null
history_quota_after?: number | null
snapshot_available: boolean
created_at: string
}
export interface AffiliateUserOverview {
user_id: number
email: string
username: string
aff_code: string
rebate_rate_percent: number
invited_count: number
rebated_invitee_count: number
available_quota: number
history_quota: number
}
export interface UpdateAffiliateUserRequest { export interface UpdateAffiliateUserRequest {
aff_code?: string aff_code?: string
aff_rebate_rate_percent?: number | null aff_rebate_rate_percent?: number | null
@ -97,12 +163,68 @@ export async function batchSetRate(
return data return data
} }
function recordParams(params: ListAffiliateRecordsParams = {}) {
return {
page: params.page ?? 1,
page_size: params.page_size ?? 20,
search: params.search ?? '',
start_at: params.start_at || undefined,
end_at: params.end_at || undefined,
sort_by: params.sort_by || undefined,
sort_order: params.sort_order || undefined,
timezone: params.timezone || undefined,
}
}
export async function listInviteRecords(
params: ListAffiliateRecordsParams = {},
): Promise<PaginatedResponse<AffiliateInviteRecord>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateInviteRecord>>(
'/admin/affiliates/invites',
{ params: recordParams(params) },
)
return data
}
export async function listRebateRecords(
params: ListAffiliateRecordsParams = {},
): Promise<PaginatedResponse<AffiliateRebateRecord>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateRebateRecord>>(
'/admin/affiliates/rebates',
{ params: recordParams(params) },
)
return data
}
export async function listTransferRecords(
params: ListAffiliateRecordsParams = {},
): Promise<PaginatedResponse<AffiliateTransferRecord>> {
const { data } = await apiClient.get<PaginatedResponse<AffiliateTransferRecord>>(
'/admin/affiliates/transfers',
{ params: recordParams(params) },
)
return data
}
export async function getUserOverview(
userId: number,
): Promise<AffiliateUserOverview> {
const { data } = await apiClient.get<AffiliateUserOverview>(
`/admin/affiliates/users/${userId}/overview`,
)
return data
}
export const affiliatesAPI = { export const affiliatesAPI = {
listUsers, listUsers,
lookupUsers, lookupUsers,
updateUserSettings, updateUserSettings,
clearUserSettings, clearUserSettings,
batchSetRate, batchSetRate,
listInviteRecords,
listRebateRecords,
listTransferRecords,
getUserOverview,
} }
export default affiliatesAPI export default affiliatesAPI

View File

@ -249,7 +249,7 @@ export interface BalanceHistoryResponse extends PaginatedResponse<BalanceHistory
* @param id - User ID * @param id - User ID
* @param page - Page number * @param page - Page number
* @param pageSize - Items per page * @param pageSize - Items per page
* @param type - Optional type filter (balance, admin_balance, concurrency, admin_concurrency, subscription) * @param type - Optional type filter (balance, affiliate_balance, admin_balance, concurrency, admin_concurrency, subscription)
* @returns Paginated balance history with total_recharged * @returns Paginated balance history with total_recharged
*/ */
export async function getUserBalanceHistory( export async function getUserBalanceHistory(

View File

@ -196,6 +196,7 @@ const totalPages = computed(() => Math.ceil(total.value / pageSize) || 1)
const typeOptions = computed(() => [ const typeOptions = computed(() => [
{ value: '', label: t('admin.users.allTypes') }, { value: '', label: t('admin.users.allTypes') },
{ value: 'balance', label: t('admin.users.typeBalance') }, { value: 'balance', label: t('admin.users.typeBalance') },
{ value: 'affiliate_balance', label: t('admin.users.typeAffiliateBalance') },
{ value: 'admin_balance', label: t('admin.users.typeAdminBalance') }, { value: 'admin_balance', label: t('admin.users.typeAdminBalance') },
{ value: 'concurrency', label: t('admin.users.typeConcurrency') }, { value: 'concurrency', label: t('admin.users.typeConcurrency') },
{ value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') }, { value: 'admin_concurrency', label: t('admin.users.typeAdminConcurrency') },
@ -235,7 +236,7 @@ const loadHistory = async (page: number) => {
const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency' const isAdminType = (type: string) => type === 'admin_balance' || type === 'admin_concurrency'
// Helper: check if balance type (includes admin_balance) // Helper: check if balance type (includes admin_balance)
const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' const isBalanceType = (type: string) => type === 'balance' || type === 'admin_balance' || type === 'affiliate_balance'
// Helper: check if subscription type // Helper: check if subscription type
const isSubscriptionType = (type: string) => type === 'subscription' const isSubscriptionType = (type: string) => type === 'subscription'
@ -291,6 +292,8 @@ const getItemTitle = (item: BalanceHistoryItem) => {
switch (item.type) { switch (item.type) {
case 'balance': case 'balance':
return t('redeem.balanceAddedRedeem') return t('redeem.balanceAddedRedeem')
case 'affiliate_balance':
return t('redeem.balanceAddedAffiliate')
case 'admin_balance': case 'admin_balance':
return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin') return item.value >= 0 ? t('redeem.balanceAddedAdmin') : t('redeem.balanceDeductedAdmin')
case 'concurrency': case 'concurrency':

View File

@ -721,6 +721,19 @@ const adminNavItems = computed((): NavItem[] => {
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon }, { path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true }, { path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true }, { path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
{
path: '/admin/affiliates',
label: t('nav.affiliateManagement'),
icon: UsersIcon,
hideInSimpleMode: true,
expandOnly: true,
featureFlag: flagAffiliate,
children: [
{ path: '/admin/affiliates/invites', label: t('nav.affiliateInviteRecords'), icon: UsersIcon },
{ path: '/admin/affiliates/rebates', label: t('nav.affiliateRebateRecords'), icon: OrderIcon },
{ path: '/admin/affiliates/transfers', label: t('nav.affiliateTransferRecords'), icon: CreditCardIcon },
],
},
{ {
path: '/admin/orders', path: '/admin/orders',
label: t('nav.orderManagement'), label: t('nav.orderManagement'),

View File

@ -347,6 +347,10 @@ export default {
usage: 'Usage', usage: 'Usage',
redeem: 'Redeem', redeem: 'Redeem',
affiliate: 'Affiliate Rebates', affiliate: 'Affiliate Rebates',
affiliateManagement: 'Affiliate Rebates',
affiliateInviteRecords: 'Invite Records',
affiliateRebateRecords: 'Rebate Records',
affiliateTransferRecords: 'Transfer Records',
profile: 'Profile', profile: 'Profile',
users: 'Users', users: 'Users',
groups: 'Groups', groups: 'Groups',
@ -1046,6 +1050,7 @@ export default {
recentActivity: 'Recent Activity', recentActivity: 'Recent Activity',
historyWillAppear: 'Your redemption history will appear here', historyWillAppear: 'Your redemption history will appear here',
balanceAddedRedeem: 'Balance Added (Redeem)', balanceAddedRedeem: 'Balance Added (Redeem)',
balanceAddedAffiliate: 'Balance Added (Affiliate Transfer)',
balanceAddedAdmin: 'Balance Added (Admin)', balanceAddedAdmin: 'Balance Added (Admin)',
balanceDeductedAdmin: 'Balance Deducted (Admin)', balanceDeductedAdmin: 'Balance Deducted (Admin)',
concurrencyAddedRedeem: 'Concurrency Added (Redeem)', concurrencyAddedRedeem: 'Concurrency Added (Redeem)',
@ -1635,6 +1640,49 @@ export default {
} }
}, },
affiliates: {
invitesDescription: 'View site-wide inviter and invitee relationships',
rebatesDescription: 'View recharge orders that generated affiliate rebates',
transfersDescription: 'View affiliate quota transfers into account balance',
errors: {
loadFailed: 'Failed to load affiliate records'
},
records: {
search: 'Search',
searchPlaceholder: 'Email, username, user ID, or order number',
startAt: 'Start date',
endAt: 'End date',
inviter: 'Inviter',
invitee: 'Invitee',
user: 'User',
affCode: 'Invite Code',
order: 'Order',
totalRebate: 'Total Rebate',
orderAmount: 'Top-up Amount',
payAmount: 'Paid Amount',
rebateAmount: 'Rebate Amount',
paymentType: 'Payment Method',
orderStatus: 'Order Status',
transferAmount: 'Transfer Amount',
balanceAfter: 'Balance After',
availableQuotaAfter: 'Available After',
frozenQuotaAfter: 'Frozen After',
historyQuotaAfter: 'Historical Rebate After',
invitedAt: 'Invited At',
rebatedAt: 'Rebated At',
transferredAt: 'Transferred At'
},
overview: {
title: 'Affiliate User Overview',
affCode: 'Invite Code',
rebateRate: 'Rebate Rate',
invitedCount: 'Invited Users',
rebatedInviteeCount: 'Rebated Invitees',
availableQuota: 'Available Quota',
historyQuota: 'Historical Rebate'
}
},
// Users // Users
users: { users: {
title: 'User Management', title: 'User Management',
@ -1787,6 +1835,7 @@ export default {
noBalanceHistory: 'No records found for this user', noBalanceHistory: 'No records found for this user',
allTypes: 'All Types', allTypes: 'All Types',
typeBalance: 'Balance (Redeem)', typeBalance: 'Balance (Redeem)',
typeAffiliateBalance: 'Balance (Affiliate Transfer)',
typeAdminBalance: 'Balance (Admin)', typeAdminBalance: 'Balance (Admin)',
typeConcurrency: 'Concurrency (Redeem)', typeConcurrency: 'Concurrency (Redeem)',
typeAdminConcurrency: 'Concurrency (Admin)', typeAdminConcurrency: 'Concurrency (Admin)',

View File

@ -347,6 +347,10 @@ export default {
usage: '使用记录', usage: '使用记录',
redeem: '兑换', redeem: '兑换',
affiliate: '邀请返利', affiliate: '邀请返利',
affiliateManagement: '邀请返利',
affiliateInviteRecords: '邀请记录',
affiliateRebateRecords: '返利记录',
affiliateTransferRecords: '提取记录',
profile: '个人资料', profile: '个人资料',
users: '用户管理', users: '用户管理',
groups: '分组管理', groups: '分组管理',
@ -1050,6 +1054,7 @@ export default {
recentActivity: '最近活动', recentActivity: '最近活动',
historyWillAppear: '您的兑换历史将显示在这里', historyWillAppear: '您的兑换历史将显示在这里',
balanceAddedRedeem: '余额充值(兑换)', balanceAddedRedeem: '余额充值(兑换)',
balanceAddedAffiliate: '余额充值(返利转入)',
balanceAddedAdmin: '余额充值(管理员)', balanceAddedAdmin: '余额充值(管理员)',
balanceDeductedAdmin: '余额扣除(管理员)', balanceDeductedAdmin: '余额扣除(管理员)',
concurrencyAddedRedeem: '并发增加(兑换)', concurrencyAddedRedeem: '并发增加(兑换)',
@ -1656,6 +1661,49 @@ export default {
} }
}, },
affiliates: {
invitesDescription: '查看全站邀请关系和被邀请用户累计返利',
rebatesDescription: '查看每一笔产生返利的充值订单',
transfersDescription: '查看返利额度转入账户余额的提取流水',
errors: {
loadFailed: '加载邀请返利记录失败'
},
records: {
search: '搜索',
searchPlaceholder: '邮箱、用户名、用户 ID、订单号',
startAt: '开始日期',
endAt: '结束日期',
inviter: '邀请人',
invitee: '被邀请人',
user: '用户',
affCode: '邀请码',
order: '订单',
totalRebate: '累计返利',
orderAmount: '充值金额',
payAmount: '支付金额',
rebateAmount: '返利金额',
paymentType: '支付方式',
orderStatus: '订单状态',
transferAmount: '提取金额',
balanceAfter: '提取后余额',
availableQuotaAfter: '提取后可提',
frozenQuotaAfter: '提取后冻结',
historyQuotaAfter: '提取后历史返利',
invitedAt: '邀请时间',
rebatedAt: '返利时间',
transferredAt: '提取时间'
},
overview: {
title: '用户返利概览',
affCode: '邀请码',
rebateRate: '返利比例',
invitedCount: '邀请人数',
rebatedInviteeCount: '已产生返利人数',
availableQuota: '可提余额',
historyQuota: '历史返利'
}
},
// Users Management // Users Management
users: { users: {
title: '用户管理', title: '用户管理',
@ -1844,6 +1892,7 @@ export default {
noBalanceHistory: '暂无变动记录', noBalanceHistory: '暂无变动记录',
allTypes: '全部类型', allTypes: '全部类型',
typeBalance: '余额(兑换码)', typeBalance: '余额(兑换码)',
typeAffiliateBalance: '余额(返利转入)',
typeAdminBalance: '余额(管理员调整)', typeAdminBalance: '余额(管理员调整)',
typeConcurrency: '并发(兑换码)', typeConcurrency: '并发(兑换码)',
typeAdminConcurrency: '并发(管理员调整)', typeAdminConcurrency: '并发(管理员调整)',

View File

@ -517,6 +517,46 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.usage.description' descriptionKey: 'admin.usage.description'
} }
}, },
{
path: '/admin/affiliates',
redirect: '/admin/affiliates/invites'
},
{
path: '/admin/affiliates/invites',
name: 'AdminAffiliateInvites',
component: () => import('@/views/admin/affiliates/AdminAffiliateInvitesView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Affiliate Invite Records',
titleKey: 'nav.affiliateInviteRecords',
descriptionKey: 'admin.affiliates.invitesDescription'
}
},
{
path: '/admin/affiliates/rebates',
name: 'AdminAffiliateRebates',
component: () => import('@/views/admin/affiliates/AdminAffiliateRebatesView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Affiliate Rebate Records',
titleKey: 'nav.affiliateRebateRecords',
descriptionKey: 'admin.affiliates.rebatesDescription'
}
},
{
path: '/admin/affiliates/transfers',
name: 'AdminAffiliateTransfers',
component: () => import('@/views/admin/affiliates/AdminAffiliateTransfersView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Affiliate Transfer Records',
titleKey: 'nav.affiliateTransferRecords',
descriptionKey: 'admin.affiliates.transfersDescription'
}
},
// ==================== Payment Admin Routes ==================== // ==================== Payment Admin Routes ====================

View File

@ -0,0 +1,7 @@
<template>
<AdminAffiliateRecordsTable type="invites" />
</template>
<script setup lang="ts">
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>

View File

@ -0,0 +1,7 @@
<template>
<AdminAffiliateRecordsTable type="rebates" />
</template>
<script setup lang="ts">
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>

View File

@ -0,0 +1,407 @@
<template>
<AppLayout>
<TablePageLayout>
<template #filters>
<div class="flex flex-wrap items-center gap-3">
<div class="relative w-full md:w-80">
<Icon name="search" size="md" class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400" />
<input v-model="filters.search" type="text" class="input pl-10" :placeholder="t('admin.affiliates.records.searchPlaceholder')" @input="debounceLoad" />
</div>
<input v-model="filters.start_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.startAt')" @change="reloadFromFirstPage" />
<input v-model="filters.end_at" type="date" class="input w-full sm:w-44" :title="t('admin.affiliates.records.endAt')" @change="reloadFromFirstPage" />
<button class="btn btn-secondary px-2 md:px-3" :disabled="loading" :title="t('common.refresh')" @click="loadRecords">
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
</div>
</template>
<template #table>
<DataTable
:columns="columns"
:data="records"
:loading="loading"
:server-side-sort="true"
default-sort-key="created_at"
default-sort-order="desc"
:sort-storage-key="sortStorageKey"
@sort="handleSort"
>
<template #cell-inviter="{ row }">
<UserCell
:id="row.inviter_id"
:email="row.inviter_email"
:username="row.inviter_username"
:clickable="props.type !== 'transfers'"
@open="openUserOverview"
/>
</template>
<template #cell-invitee="{ row }">
<UserCell
:id="row.invitee_id"
:email="row.invitee_email"
:username="row.invitee_username"
:clickable="props.type !== 'transfers'"
@open="openUserOverview"
/>
</template>
<template #cell-user="{ row }">
<UserCell
:id="row.user_id"
:email="row.user_email"
:username="row.username"
:clickable="true"
@open="openUserOverview"
/>
</template>
<template #cell-aff_code="{ row }">
<span class="font-mono text-sm text-gray-700 dark:text-gray-300">{{ row.aff_code || '-' }}</span>
</template>
<template #cell-order="{ row }">
<div class="space-y-0.5">
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ row.order_id }}</div>
<div class="max-w-56 truncate text-sm text-gray-500 dark:text-dark-400">{{ row.out_trade_no }}</div>
</div>
</template>
<template #cell-payment_type="{ row }">
{{ t('payment.methods.' + row.payment_type, row.payment_type || '-') }}
</template>
<template #cell-order_status="{ row }">
<OrderStatusBadge :status="row.order_status" />
</template>
<template #cell-total_rebate="{ row }">
<AmountText :value="row.total_rebate" />
</template>
<template #cell-order_amount="{ row }">
<AmountText :value="row.order_amount" />
</template>
<template #cell-pay_amount="{ row }">
<span class="text-sm text-gray-900 dark:text-white">¥{{ formatAmount(row.pay_amount) }}</span>
</template>
<template #cell-rebate_amount="{ row }">
<AmountText :value="row.rebate_amount" strong />
</template>
<template #cell-amount="{ row }">
<AmountText :value="row.amount" strong />
</template>
<template #cell-balance_after="{ row }">
<NullableAmountText :value="row.balance_after" />
</template>
<template #cell-available_quota_after="{ row }">
<NullableAmountText :value="row.available_quota_after" />
</template>
<template #cell-frozen_quota_after="{ row }">
<NullableAmountText :value="row.frozen_quota_after" />
</template>
<template #cell-history_quota_after="{ row }">
<NullableAmountText :value="row.history_quota_after" />
</template>
<template #cell-created_at="{ row }">
<span class="text-sm text-gray-700 dark:text-gray-300">{{ formatDateTime(row.created_at) }}</span>
</template>
</DataTable>
</template>
<template #pagination>
<Pagination
v-if="pagination.total > 0"
:page="pagination.page"
:total="pagination.total"
:page-size="pagination.page_size"
@update:page="handlePageChange"
@update:pageSize="handlePageSizeChange"
/>
</template>
</TablePageLayout>
<BaseDialog
:show="overviewDialog"
:title="t('admin.affiliates.overview.title')"
width="normal"
@close="overviewDialog = false"
>
<div v-if="overviewLoading" class="flex justify-center py-8">
<div class="h-6 w-6 animate-spin rounded-full border-2 border-primary-500 border-t-transparent"></div>
</div>
<div v-else-if="selectedOverview" class="space-y-4">
<div class="rounded-lg border border-gray-100 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800">
<div class="font-mono text-sm text-gray-900 dark:text-white">#{{ selectedOverview.user_id }}</div>
<div class="mt-1 text-sm font-medium text-gray-900 dark:text-white">{{ selectedOverview.email || '-' }}</div>
<div class="mt-0.5 text-sm text-gray-500 dark:text-dark-400">{{ selectedOverview.username || '-' }}</div>
</div>
<div class="grid gap-3 sm:grid-cols-2">
<OverviewStat :label="t('admin.affiliates.overview.affCode')" :value="selectedOverview.aff_code || '-'" mono />
<OverviewStat :label="t('admin.affiliates.overview.rebateRate')" :value="formatPercent(selectedOverview.rebate_rate_percent)" />
<OverviewStat :label="t('admin.affiliates.overview.invitedCount')" :value="String(selectedOverview.invited_count)" />
<OverviewStat :label="t('admin.affiliates.overview.rebatedInviteeCount')" :value="String(selectedOverview.rebated_invitee_count)" />
<OverviewStat :label="t('admin.affiliates.overview.availableQuota')" :value="'$' + formatAmount(selectedOverview.available_quota)" />
<OverviewStat :label="t('admin.affiliates.overview.historyQuota')" :value="'$' + formatAmount(selectedOverview.history_quota)" />
</div>
</div>
</BaseDialog>
</AppLayout>
</template>
<script setup lang="ts">
import { computed, defineComponent, h, onMounted, reactive, ref, type PropType } from 'vue'
import { useI18n } from 'vue-i18n'
import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import DataTable from '@/components/common/DataTable.vue'
import Pagination from '@/components/common/Pagination.vue'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Icon from '@/components/icons/Icon.vue'
import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue'
import type { Column } from '@/components/common/types'
import { useAppStore } from '@/stores/app'
import { affiliatesAPI, type AffiliateInviteRecord, type AffiliateRebateRecord, type AffiliateTransferRecord, type AffiliateUserOverview, type ListAffiliateRecordsParams } from '@/api/admin/affiliates'
import type { PaginatedResponse } from '@/types'
import { extractI18nErrorMessage } from '@/utils/apiError'
import { formatDateTime as formatDisplayDateTime } from '@/utils/format'
type RecordType = 'invites' | 'rebates' | 'transfers'
type AffiliateRecord = AffiliateInviteRecord | AffiliateRebateRecord | AffiliateTransferRecord
const props = defineProps<{
type: RecordType
}>()
const { t } = useI18n()
const appStore = useAppStore()
const loading = ref(false)
const records = ref<AffiliateRecord[]>([])
const filters = reactive({ search: '', start_at: '', end_at: '' })
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
const overviewDialog = ref(false)
const overviewLoading = ref(false)
const selectedOverview = ref<AffiliateUserOverview | null>(null)
let debounceTimer: ReturnType<typeof setTimeout> | null = null
const columns = computed<Column[]>(() => {
if (props.type === 'invites') {
return [
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
{ key: 'aff_code', label: t('admin.affiliates.records.affCode'), sortable: true },
{ key: 'total_rebate', label: t('admin.affiliates.records.totalRebate'), sortable: true },
{ key: 'created_at', label: t('admin.affiliates.records.invitedAt'), sortable: true },
]
}
if (props.type === 'rebates') {
return [
{ key: 'order', label: t('admin.affiliates.records.order'), sortable: true },
{ key: 'inviter', label: t('admin.affiliates.records.inviter'), sortable: true },
{ key: 'invitee', label: t('admin.affiliates.records.invitee'), sortable: true },
{ key: 'order_amount', label: t('admin.affiliates.records.orderAmount'), sortable: true },
{ key: 'pay_amount', label: t('admin.affiliates.records.payAmount'), sortable: true },
{ key: 'rebate_amount', label: t('admin.affiliates.records.rebateAmount') },
{ key: 'payment_type', label: t('admin.affiliates.records.paymentType'), sortable: true },
{ key: 'order_status', label: t('admin.affiliates.records.orderStatus'), sortable: true },
{ key: 'created_at', label: t('admin.affiliates.records.rebatedAt'), sortable: true },
]
}
return [
{ key: 'user', label: t('admin.affiliates.records.user'), sortable: true },
{ key: 'amount', label: t('admin.affiliates.records.transferAmount'), sortable: true },
{ key: 'balance_after', label: t('admin.affiliates.records.balanceAfter'), sortable: true },
{ key: 'available_quota_after', label: t('admin.affiliates.records.availableQuotaAfter'), sortable: true },
{ key: 'frozen_quota_after', label: t('admin.affiliates.records.frozenQuotaAfter'), sortable: true },
{ key: 'history_quota_after', label: t('admin.affiliates.records.historyQuotaAfter'), sortable: true },
{ key: 'created_at', label: t('admin.affiliates.records.transferredAt'), sortable: true },
]
})
const sortStorageKey = computed(() => `admin-affiliate-${props.type}-table-sort`)
function loadInitialSortState(): { sort_by: string; sort_order: 'asc' | 'desc' } {
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
try {
const raw = localStorage.getItem(sortStorageKey.value)
if (!raw) return fallback
const parsed = JSON.parse(raw) as { key?: string; order?: string }
const key = typeof parsed.key === 'string' ? parsed.key : ''
if (!columns.value.some((column) => column.key === key && column.sortable)) return fallback
return {
sort_by: key,
sort_order: parsed.order === 'asc' ? 'asc' : 'desc',
}
} catch {
return fallback
}
}
const sortState = reactive(loadInitialSortState())
function userTimezone(): string {
try {
return Intl.DateTimeFormat().resolvedOptions().timeZone
} catch {
return 'UTC'
}
}
function buildParams(): ListAffiliateRecordsParams {
return {
page: pagination.page,
page_size: pagination.page_size,
search: filters.search.trim() || undefined,
start_at: filters.start_at || undefined,
end_at: filters.end_at || undefined,
sort_by: sortState.sort_by,
sort_order: sortState.sort_order,
timezone: userTimezone(),
}
}
async function fetchRecords(params: ListAffiliateRecordsParams): Promise<PaginatedResponse<AffiliateRecord>> {
if (props.type === 'invites') {
return affiliatesAPI.listInviteRecords(params)
}
if (props.type === 'rebates') {
return affiliatesAPI.listRebateRecords(params)
}
return affiliatesAPI.listTransferRecords(params)
}
async function loadRecords() {
loading.value = true
try {
const res = await fetchRecords(buildParams())
records.value = res.items || []
pagination.total = res.total || 0
} catch (error) {
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
} finally {
loading.value = false
}
}
function debounceLoad() {
if (debounceTimer) clearTimeout(debounceTimer)
debounceTimer = setTimeout(() => reloadFromFirstPage(), 300)
}
function reloadFromFirstPage() {
pagination.page = 1
void loadRecords()
}
function handlePageChange(page: number) {
pagination.page = page
void loadRecords()
}
function handlePageSizeChange(size: number) {
pagination.page_size = size
pagination.page = 1
void loadRecords()
}
function handleSort(key: string, order: 'asc' | 'desc') {
sortState.sort_by = key
sortState.sort_order = order
pagination.page = 1
void loadRecords()
}
function formatAmount(value: number | null | undefined): string {
return Number(value || 0).toFixed(2)
}
function formatPercent(value: number | null | undefined): string {
const rounded = Math.round(Number(value || 0) * 100) / 100
return `${Number.isInteger(rounded) ? rounded.toString() : rounded.toString()}%`
}
function formatDateTime(value: string | null | undefined): string {
return value ? formatDisplayDateTime(value) : '-'
}
async function openUserOverview(userId: number) {
if (!userId) return
overviewDialog.value = true
overviewLoading.value = true
selectedOverview.value = null
try {
selectedOverview.value = await affiliatesAPI.getUserOverview(userId)
} catch (error) {
overviewDialog.value = false
appStore.showError(extractI18nErrorMessage(error, t, 'admin.affiliates.errors', t('common.error')))
} finally {
overviewLoading.value = false
}
}
const UserCell = defineComponent({
props: {
id: { type: Number, required: true },
email: { type: String, default: '' },
username: { type: String, default: '' },
clickable: { type: Boolean, default: false },
},
emits: ['open'],
setup(cellProps, { emit }) {
return () => h('div', { class: 'space-y-0.5' }, [
h('div', { class: 'font-mono text-sm text-gray-900 dark:text-white' }, `#${cellProps.id}`),
h(cellProps.clickable ? 'button' : 'div', {
class: cellProps.clickable
? 'max-w-56 truncate text-left text-sm font-medium text-primary-600 hover:text-primary-700 hover:underline dark:text-primary-400 dark:hover:text-primary-300'
: 'max-w-56 truncate text-sm text-gray-700 dark:text-gray-300',
type: cellProps.clickable ? 'button' : undefined,
onClick: cellProps.clickable ? () => emit('open', cellProps.id) : undefined,
}, cellProps.email || '-'),
h('div', { class: 'max-w-56 truncate text-sm text-gray-500 dark:text-dark-400' }, cellProps.username || '-'),
])
},
})
const AmountText = defineComponent({
props: {
value: { type: Number, default: 0 },
strong: { type: Boolean, default: false },
},
setup(amountProps) {
return () => h('span', {
class: amountProps.strong
? 'text-sm font-semibold text-emerald-600 dark:text-emerald-400'
: 'text-sm text-gray-900 dark:text-white',
}, `$${formatAmount(amountProps.value)}`)
},
})
const NullableAmountText = defineComponent({
props: {
value: { type: Number as PropType<number | null | undefined>, default: null },
},
setup(amountProps) {
return () => {
const value = amountProps.value
if (value === null || value === undefined) {
return h('span', { class: 'text-sm text-gray-400 dark:text-dark-500' }, '-')
}
return h(AmountText, { value })
}
},
})
const OverviewStat = defineComponent({
props: {
label: { type: String, required: true },
value: { type: String, required: true },
mono: { type: Boolean, default: false },
},
setup(statProps) {
return () => h('div', { class: 'rounded-lg border border-gray-100 bg-white p-3 dark:border-dark-700 dark:bg-dark-900' }, [
h('div', { class: 'text-sm text-gray-500 dark:text-dark-400' }, statProps.label),
h('div', {
class: statProps.mono
? 'mt-1 font-mono text-base font-semibold text-gray-900 dark:text-white'
: 'mt-1 text-base font-semibold text-gray-900 dark:text-white',
}, statProps.value),
])
},
})
onMounted(() => {
void loadRecords()
})
</script>

View File

@ -0,0 +1,7 @@
<template>
<AdminAffiliateRecordsTable type="transfers" />
</template>
<script setup lang="ts">
import AdminAffiliateRecordsTable from './AdminAffiliateRecordsTable.vue'
</script>